jan-hq committed
Commit b46f992
1 Parent(s): 68fc348

Upload folder using huggingface_hub
app.py CHANGED
@@ -1,4 +1,4 @@
-import argparse
+import argparse, os, sys
 parser = argparse.ArgumentParser(description="WhisperVQ Application")
 parser.add_argument('--log-path', type=str,
                     default='whisper.log', help='The log file path')
@@ -6,32 +6,22 @@ parser.add_argument('--log-level', type=str, default='INFO',
                     choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'TRACE'], help='The log level')
 parser.add_argument('--port', type=int, default=3348,
                     help='The port to run the WhisperVQ app on')
+parser.add_argument('--device-id', type=str, default="0",
+                    help='The id of the GPU device to run the WhisperVQ app on')
 parser.add_argument('--package-dir', type=str, default="",
                     help='The package-dir to be extended to sys.path')
 args = parser.parse_args()
-import sys
-sys.path.insert(0, args.environment)
-import tempfile
-from typing import Tuple
-from enum import Enum
-import io
+sys.path.insert(0, args.package_dir)
+os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id  # Use the selected Nvidia GPU
+
 import logging
-from custom_component import CustomRQBottleneckTransformer
-from whisperspeech.vq_stoks import RQBottleneckTransformer
-from huggingface_hub import hf_hub_download
 import uvicorn
-from transformers import WhisperModel, WhisperProcessor
-from fastapi.responses import JSONResponse
-from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi import FastAPI
 from contextlib import asynccontextmanager
-import torchaudio
-import torch
 import os
 import time
 import psutil
 import threading
-
-
 logging.basicConfig(level=args.log_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                     handlers=[
                         logging.FileHandler(args.log_path),
@@ -39,200 +29,24 @@ logging.basicConfig(level=args.log_level, format='%(asctime)s - %(name)s - %(lev
                     ])
 logger = logging.getLogger(__name__)
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use the first GPU
-
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-if not os.path.exists(os.path.dirname(os.path.realpath(__file__))+"/whisper-vq-stoks-v3-7lang-fixed.model"):
-    hf_hub_download(
-        repo_id="jan-hq/WhisperVQ",
-        filename="whisper-vq-stoks-v3-7lang-fixed.model",
-        local_dir=".",
-    )
-vq_model = CustomRQBottleneckTransformer.load_vq_only(
-    os.path.dirname(os.path.realpath(__file__)) +
-    "/whisper-vq-stoks-v3-7lang-fixed.model"
-).to(device)
-vq_model.load_encoder(device)
-vq_model.eval()
+# after setting up the logger we can import and use the services
+from services.AudioTokenizerService import get_audio_tokenizer_service
+from routes.AudioTokenizerRoute import audio_tokenizer_router
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-
+    # on startup, warm up the tokenizer service
+    get_audio_tokenizer_service()
     yield
     # on shutdown
 
-
-# vq_model = torch.compile(vq_model)
-
-
-class AudioFormat(str, Enum):
-    WAV = "wav"    # Supported by both backends
-    MP3 = "mp3"    # Supported by ffmpeg
-    FLAC = "flac"  # Supported by both
-    AAC = "aac"    # Supported by ffmpeg
-    OGG = "ogg"    # Supported by ffmpeg
-    OPUS = "opus"  # Supported by ffmpeg
-    PCM = "pcm"    # Raw PCM data
-
-
-# Format to backend mapping
-FORMAT_BACKENDS = {
-    AudioFormat.WAV: ["soundfile", "ffmpeg"],
-    AudioFormat.MP3: ["ffmpeg"],
-    AudioFormat.FLAC: ["soundfile", "ffmpeg"],
-    AudioFormat.AAC: ["ffmpeg"],
-    AudioFormat.OGG: ["ffmpeg"],
-    AudioFormat.OPUS: ["ffmpeg"],
-    AudioFormat.PCM: ["soundfile"]
-}
-
-
-class AudioProcessor:
-    def __init__(self):
-        self.available_backends = torchaudio.list_audio_backends()
-        logger.info(f"Available backends: {self.available_backends}")
-
-        # Verify ffmpeg support
-        self.has_ffmpeg = "ffmpeg" in self.available_backends
-        if not self.has_ffmpeg:
-            logger.warning(
-                "FFMPEG backend not available. Some formats may not be supported")
-
-    def _get_best_backend(self, format: AudioFormat) -> str:
-        """Determine the best backend for the given format"""
-        supported_backends = FORMAT_BACKENDS[format]
-        for backend in supported_backends:
-            if backend in self.available_backends:
-                return backend
-        raise ValueError(f"No available backend supports format {format}")
-
-    async def load_audio(
-        self,
-        file_obj: bytes,
-        format: AudioFormat,
-        target_sr: int = 16000
-    ) -> Tuple[torch.Tensor, int]:
-        """
-        Load audio from bytes object with format handling
-
-        Args:
-            file_obj: Audio file bytes
-            format: Audio format enum
-            target_sr: Target sample rate (default: 16000)
-
-        Returns:
-            Tuple[torch.Tensor, int]: Audio tensor and sample rate
-        """
-        try:
-            # Get appropriate backend
-            backend = self._get_best_backend(format)
-            torchaudio.set_audio_backend(backend)
-            logger.info(f"Using {backend} backend for {format} format")
-
-            if format == AudioFormat.PCM:
-                # Handle raw PCM
-                wav = torch.frombuffer(file_obj, dtype=torch.int16)
-                wav = wav.float() / 32768.0  # Normalize to [-1, 1]
-                wav = wav.unsqueeze(0)  # Add channel dimension
-                sr = target_sr
-            else:
-                # For formats that might need ffmpeg processing
-                if os.name == "nt":  # for windows
-                    wav, sr = torchaudio.load(io.BytesIO(file_obj))
-                else:
-                    with tempfile.NamedTemporaryFile(suffix=f".{format}") as temp_file:
-                        # Write bytes to temporary file
-                        temp_file.write(file_obj)
-                        temp_file.flush()
-
-                        # Load audio
-                        wav, sr = torchaudio.load(temp_file.name)
-
-            # Convert to mono if stereo
-            if wav.shape[0] > 1:
-                wav = torch.mean(wav, dim=0, keepdim=True)
-
-            # Resample if needed
-            if sr != target_sr:
-                wav = torchaudio.functional.resample(wav, sr, target_sr)
-                sr = target_sr
-
-            return wav, sr
-
-        except Exception as e:
-            logger.error(f"Error loading audio: {e}")
-            raise HTTPException(
-                status_code=400,
-                detail=f"Error processing {format} audio: {str(e)}"
-            )
-
-    def get_format_info(self) -> dict:
-        """Get information about supported formats"""
-        supported_formats = {}
-        for format in AudioFormat:
-            try:
-                backend = self._get_best_backend(format)
-                supported_formats[format] = {
-                    "supported": True,
-                    "backend": backend
-                }
-            except ValueError:
-                supported_formats[format] = {
-                    "supported": False,
-                    "backend": None
-                }
-        return supported_formats
-
-
-audio_processor = AudioProcessor()
-
 app = FastAPI(lifespan=lifespan)
 
-
-@app.get("/supported_formats")
-async def get_supported_formats():
-    """Endpoint to check supported formats"""
-    return audio_processor.get_format_info()
-
-
-@app.post("/tokenize/{format}")
-async def tokenize_audio(format: AudioFormat = "wav", file: UploadFile = File(...)):
-    try:
-        # Read file
-        file_obj = await file.read()
-
-        # Load and process audio
-        wav, sr = await audio_processor.load_audio(file_obj, format)
-
-        # Ensure we're using CUDA if available
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        wav = wav.to(device)
-
-        # Generate tokens
-        with torch.no_grad():
-            codes = vq_model.encode_audio(wav)
-            codes = codes[0].cpu().tolist()
-
-        # Format result
-        result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
-
-        return JSONResponse(content={
-            "model_name": "whisper-vq-stoks-v3-7lang-fixed.model",
-            "tokens": f'<|sound_start|>{result}<|sound_end|>',
-            "format": format,
-            "sample_rate": sr,
-            "backend_used": audio_processor._get_best_backend(format)
-        })
-
-    except Exception as e:
-        logger.error(f"Error processing request: {e}")
-        raise HTTPException(
-            status_code=500,
-            detail=f"Error processing request: {str(e)}"
-        )
-
+# include the routes
+app.include_router(audio_tokenizer_router)
 
 def self_terminate():
     time.sleep(1)
@@ -240,8 +54,8 @@ def self_terminate():
     parent.kill()
 
 
-@app.post("/kill")
-async def kill():
+@app.delete("/destroy")
+async def destroy():
     threading.Thread(target=self_terminate, daemon=True).start()
     return {"success": True}
 
@@ -263,8 +77,7 @@ if __name__ == "__main__":
     LOGGING_CONFIG["loggers"]["uvicorn.access"]["level"] = args.log_level
 
     # Print supported formats at startup
-    processor = AudioProcessor()
-    format_info = processor.get_format_info()
+    format_info = get_audio_tokenizer_service().get_format_info()
    logger.info("Supported formats:")
    for format, info in format_info.items():
        logger.info(f"{format}: {info}")
models/audio.py ADDED
@@ -0,0 +1,22 @@
+from pydantic import BaseModel
+from enum import Enum
+
+class AudioFormat(str, Enum):
+    WAV = "wav"    # Supported by both backends
+    MP3 = "mp3"    # Supported by ffmpeg
+    FLAC = "flac"  # Supported by both
+    AAC = "aac"    # Supported by ffmpeg
+    OGG = "ogg"    # Supported by ffmpeg
+    OPUS = "opus"  # Supported by ffmpeg
+    PCM = "pcm"    # Raw PCM data
+
+# Format to backend mapping
+FORMAT_BACKENDS = {
+    AudioFormat.WAV: ["soundfile", "ffmpeg"],
+    AudioFormat.MP3: ["ffmpeg"],
+    AudioFormat.FLAC: ["soundfile", "ffmpeg"],
+    AudioFormat.AAC: ["ffmpeg"],
+    AudioFormat.OGG: ["ffmpeg"],
+    AudioFormat.OPUS: ["ffmpeg"],
+    AudioFormat.PCM: ["soundfile"]
+}
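
Because AudioFormat subclasses str, FastAPI can coerce the raw path segment straight into the enum, and the "wav" default in the route signature compares equal to AudioFormat.WAV. A quick illustration (hypothetical, not part of the commit):

    from models.audio import AudioFormat

    assert AudioFormat("wav") is AudioFormat.WAV  # path params coerce to the enum
    assert AudioFormat.WAV == "wav"               # str subclass compares equal to plain strings
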
routes/AudioTokenizerRoute.py ADDED
@@ -0,0 +1,19 @@
+from services.AudioTokenizerService import get_audio_tokenizer_service
+from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import File, UploadFile
+from models.audio import AudioFormat, FORMAT_BACKENDS
+
+audio_tokenizer_router = APIRouter(
+    prefix="/tokenize", tags=["audio"])
+
+
+@audio_tokenizer_router.post("/{format}")
+async def tokenize_audio(format: AudioFormat = "wav", file: UploadFile = File(...)):
+    file_obj = await file.read()
+    result = get_audio_tokenizer_service().tokenize(file_obj, format)
+    return result
+
+
+@audio_tokenizer_router.get("/supported_formats")
+async def get_supported_formats():
+    return get_audio_tokenizer_service().get_format_info()
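
For reference, a client would exercise the new router roughly as follows; this sketch assumes the server is running on the default port 3348 and uses the requests library, neither of which is part of the commit:

    import requests

    with open("sample.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:3348/tokenize/wav",  # POST /tokenize/{format}
            files={"file": ("sample.wav", f, "audio/wav")},
        )
    resp.raise_for_status()
    print(resp.json()["tokens"])  # '<|sound_start|><|sound_0123|>...<|sound_end|>'
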
services/AudioTokenizerService.py ADDED
@@ -0,0 +1,167 @@
+import io
+import os
+from huggingface_hub import hf_hub_download
+from models.audio import AudioFormat, FORMAT_BACKENDS
+import tempfile
+import logging
+import torchaudio
+from fastapi import HTTPException
+from fastapi.responses import JSONResponse
+import torch
+from typing import Tuple
+from utils.custom_component import CustomRQBottleneckTransformer
+logger = logging.getLogger(__name__)
+
+
+class AudioTokenizerService:
+    def __init__(self):
+        self.available_backends = torchaudio.list_audio_backends()
+        logger.info(f"Available backends: {self.available_backends}")
+        main_directory = os.path.dirname(
+            os.path.dirname(os.path.realpath(__file__)))
+
+        # Verify ffmpeg support
+        self.has_ffmpeg = "ffmpeg" in self.available_backends
+        if not self.has_ffmpeg:
+            logger.warning(
+                "FFMPEG backend not available. Some formats may not be supported")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if not os.path.exists(main_directory+"/whisper-vq-stoks-v3-7lang-fixed.model"):
+            hf_hub_download(
+                repo_id="jan-hq/WhisperVQ",
+                filename="whisper-vq-stoks-v3-7lang-fixed.model",
+                local_dir=main_directory,
+            )
+        self.vq_model = CustomRQBottleneckTransformer.load_vq_only(
+            main_directory +
+            "/whisper-vq-stoks-v3-7lang-fixed.model"
+        ).to(device)
+        self.vq_model.load_encoder(device)
+        self.vq_model.eval()
+        # self.vq_model = torch.compile(self.vq_model)
+
+    def _get_best_backend(self, format: AudioFormat) -> str:
+        """Determine the best backend for the given format"""
+        supported_backends = FORMAT_BACKENDS[format]
+        for backend in supported_backends:
+            if backend in self.available_backends:
+                return backend
+        raise ValueError(f"No available backend supports format {format}")
+
+    def load_audio(
+        self,
+        file_obj: bytes,
+        format: AudioFormat,
+        target_sr: int = 16000
+    ) -> Tuple[torch.Tensor, int]:
+        """
+        Load audio from bytes object with format handling
+
+        Args:
+            file_obj: Audio file bytes
+            format: Audio format enum
+            target_sr: Target sample rate (default: 16000)
+
+        Returns:
+            Tuple[torch.Tensor, int]: Audio tensor and sample rate
+        """
+        try:
+            # Get appropriate backend
+            backend = self._get_best_backend(format)
+            torchaudio.set_audio_backend(backend)
+            logger.info(f"Using {backend} backend for {format} format")
+
+            if format == AudioFormat.PCM:
+                # Handle raw PCM
+                wav = torch.frombuffer(file_obj, dtype=torch.int16)
+                wav = wav.float() / 32768.0  # Normalize to [-1, 1]
+                wav = wav.unsqueeze(0)  # Add channel dimension
+                sr = target_sr
+            else:
+                # For formats that might need ffmpeg processing
+                if os.name == "nt":  # for windows
+                    wav, sr = torchaudio.load(io.BytesIO(file_obj))
+                else:
+                    with tempfile.NamedTemporaryFile(suffix=f".{format}") as temp_file:
+                        # Write bytes to temporary file
+                        temp_file.write(file_obj)
+                        temp_file.flush()
+
+                        # Load audio
+                        wav, sr = torchaudio.load(temp_file.name)
+
+            # Convert to mono if stereo
+            if wav.shape[0] > 1:
+                wav = torch.mean(wav, dim=0, keepdim=True)
+
+            # Resample if needed
+            if sr != target_sr:
+                wav = torchaudio.functional.resample(wav, sr, target_sr)
+                sr = target_sr
+
+            return wav, sr
+
+        except Exception as e:
+            logger.error(f"Error loading audio: {e}")
+            raise HTTPException(
+                status_code=400,
+                detail=f"Error processing {format} audio: {str(e)}"
+            )
+
+    def get_format_info(self) -> dict:
+        """Get information about supported formats"""
+        supported_formats = {}
+        for format in AudioFormat:
+            try:
+                backend = self._get_best_backend(format)
+                supported_formats[format] = {
+                    "supported": True,
+                    "backend": backend
+                }
+            except ValueError:
+                supported_formats[format] = {
+                    "supported": False,
+                    "backend": None
+                }
+        return supported_formats
+
+    def tokenize(self, audio_data: bytes, format: AudioFormat = "wav"):
+        try:
+            wav, sr = self.load_audio(audio_data, format)
+
+            # Ensure we're using CUDA if available
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            wav = wav.to(device)
+
+            # Generate tokens
+            with torch.no_grad():
+                codes = self.vq_model.encode_audio(wav)
+                codes = codes[0].cpu().tolist()
+
+            # Format result
+            result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
+
+            return JSONResponse(content={
+                "model_name": "whisper-vq-stoks-v3-7lang-fixed.model",
+                "tokens": f'<|sound_start|>{result}<|sound_end|>',
+                "format": format,
+                "sample_rate": sr,
+                "backend_used": self._get_best_backend(format)
+            })
+
+        except Exception as e:
+            logger.error(f"Error processing request: {e}")
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error processing request: {str(e)}"
+            )
+
+
+_audio_tokenizer_service = None
+
+
+def get_audio_tokenizer_service():
+    global _audio_tokenizer_service
+    if _audio_tokenizer_service is None:
+        _audio_tokenizer_service = AudioTokenizerService()
+    return _audio_tokenizer_service
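
get_audio_tokenizer_service implements a lazy module-level singleton, so the VQ model is loaded exactly once: at startup via the lifespan hook, or on first use otherwise. The service can also be driven without the HTTP layer; a minimal offline sketch, assuming the repository root is on sys.path and a local sample.wav exists:

    from services.AudioTokenizerService import get_audio_tokenizer_service

    service = get_audio_tokenizer_service()  # first call downloads/loads the VQ model
    with open("sample.wav", "rb") as f:
        response = service.tokenize(f.read(), "wav")  # returns a FastAPI JSONResponse
    print(response.body)
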
utils/custom_component.py ADDED
@@ -0,0 +1,195 @@
+import torch
+import torch.nn as nn
+import whisper
+from whisper.model import AudioEncoder, ModelDimensions
+from typing import Dict, Optional
+from whisperspeech.vq_stoks import RQBottleneckTransformer, Tunables
+from huggingface_hub import hf_hub_download
+import torch.nn.functional as F
+import os
+from typing import List, Optional, Union
+import io
+import urllib
+from tqdm import tqdm
+import torchaudio
+
+_HF_MODELS = {
+    "medium": "https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt",
+}
+
+
+def available_models() -> List[str]:
+    """Returns the names of available models"""
+    return list(_HF_MODELS.keys())
+
+
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
+    os.makedirs(root, exist_ok=True)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, os.path.basename(url))
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(
+            f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        with open(download_target, "rb") as f:
+            model_bytes = f.read()
+        return model_bytes if in_memory else download_target
+    import ssl
+    ssl._create_default_https_context = ssl._create_unverified_context
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    model_bytes = open(download_target, "rb").read()
+    return model_bytes if in_memory else download_target
+
+
+class CustomWhisperEncoder(nn.Module):
+    """
+    Lightweight wrapper that only loads the AudioEncoder part of Whisper
+    """
+
+    def __init__(self, name: str, device: str = None, download_root: str = None, in_memory: bool = False,):
+        super().__init__()
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        if download_root is None:
+            default = os.path.join(os.path.expanduser("~"), ".cache")
+            # os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
+            download_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+        if name in _HF_MODELS:
+            checkpoint_file = _download(
+                _HF_MODELS[name], download_root, in_memory)
+        elif os.path.isfile(name):
+            checkpoint_file = open(name, "rb").read() if in_memory else name
+        else:
+            raise RuntimeError(
+                f"Model {name} not found; available models = {available_models()}"
+            )
+
+        # Load weights
+        with (
+            io.BytesIO(checkpoint_file) if in_memory else open(
+                checkpoint_file, "rb")
+        ) as fp:
+            checkpoint = torch.load(fp, map_location=device)
+        del checkpoint_file
+        dims = ModelDimensions(**checkpoint["dims"])
+        self.encoder = AudioEncoder(
+            dims.n_mels,
+            dims.n_audio_ctx,
+            dims.n_audio_state,
+            dims.n_audio_head,
+            dims.n_audio_layer,
+        )
+
+        self.encoder.load_state_dict(checkpoint["model_state_dict"])
+
+        if device:
+            self.to(device)
+
+        self.eval()
+
+    def forward(self, mel: torch.Tensor):
+        return self.encoder(mel)
+
+
+class CustomRQBottleneckTransformer(RQBottleneckTransformer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
+                     repo_id=None, filename=None, local_filename=None):
+        if repo_id is None and filename is None and local_filename is None:
+            if ":" in ref:
+                repo_id, filename = ref.split(":", 1)
+            else:
+                local_filename = ref
+        if not local_filename:
+            local_filename = hf_hub_download(
+                repo_id=repo_id, filename=filename)
+
+        # Load the spec
+        spec = torch.load(local_filename)
+
+        # Create instance with minimal required components
+        instance = cls(**spec['config'], tunables=Tunables(**
+                       Tunables.upgrade(spec.get('tunables', {}))))
+
+        # Load only necessary state dict entries
+        required_components = {
+            'rq', 'mlp', 'mlp_ln'
+        }
+        filtered_state_dict = {
+            k: v for k, v in spec['state_dict'].items()
+            if any(k.startswith(comp) for comp in required_components)
+        }
+
+        instance.load_state_dict(filtered_state_dict, strict=False)
+        instance.eval()
+        return instance
+
+    def load_encoder(self, device=None):
+        if self.whmodel is not None:
+            return
+        device = device or self.device
+        # Use our custom encoder-only model
+        if self.whmodel is None:
+            encoder = CustomWhisperEncoder(
+                self.whisper_model_name, device=device)
+            self.whmodel = encoder
+        multilingual = not self.whisper_model_name.endswith('.en')
+        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)
+
+    def optimzed_encode_mel(self, mel):
+        assert len(
+            mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
+        self.load_encoder()
+        n = mel.shape[-1]
+        if n > whisper.audio.N_FRAMES:
+            padding = 0
+            padded = mel[:, :, :whisper.audio.N_FRAMES]
+        else:
+            padding = -n % whisper.audio.N_FRAMES
+            padded = F.pad(mel, (0, padding), value=-1.5)
+        # .to(self.whmodel[0].device))#[:,:n//2]
+        embs = self.whmodel.encoder(padded)
+        stoks = self.quantize(embs)
+        if self.tunables.mask_embs:
+            return stoks[:, :n//2//self.downsample]
+        else:
+            return stoks
+
+    # override
+    def encode_audio(self, audio):
+        if isinstance(audio, str):
+            x, sr = torchaudio.load(audio)
+            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
+            audio = x.unsqueeze(0)
+        return self.optimzed_encode_mel(self.log_mel_spectrogram(audio).to(self.device))
+
+
+if __name__ == "__main__":
+    # Load the model
+    vqmodel = CustomRQBottleneckTransformer.load_vq_only(
+        "whisper-vq-stoks-v3-7lang-fixed.model"
+    ).to("cuda")
+    vqmodel.load_encoder('cuda')
+    vqmodel.eval()
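
The padding rule in optimzed_encode_mel pads a short mel spectrogram up to one full Whisper window (longer inputs are truncated to 30 s) and, when mask_embs is set, trims the quantized tokens back to the unpadded length. A worked example, assuming whisper.audio.N_FRAMES == 3000 (30 s at 100 mel frames per second) and a downsample factor of 2:

    n = 1000                   # mel frames in a 10 s clip
    padding = -n % 3000        # 2000 frames of padding up to the 3000-frame window
    assert n + padding == 3000
    tokens_kept = n // 2 // 2  # 250 semantic tokens survive the mask_embs trim
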