kevinwang676 commited on
Commit
0211c0b
Β·
verified Β·
1 Parent(s): 3043352

Update runtime/python/grpc/server.py

Browse files
Files changed (1) hide show
  1. runtime/python/grpc/server.py +83 -57
runtime/python/grpc/server.py CHANGED
@@ -59,35 +59,66 @@ def _yield_audio(model_output):
59
  resp = cosyvoice_pb2.Response(tts_audio=pcm16.tobytes())
60
  yield resp
61
 
62
- import urllib.parse
 
63
 
64
  def _load_prompt_from_url(url: str, target_sr: int = 16_000) -> torch.Tensor:
65
- """
66
- Download *url* (wav / mp3 / flac) ➜ mono torch.FloatTensor [1,β€―T] @ target_sr.
67
- The temp file is removed before return.
68
- """
 
69
  resp = requests.get(url, timeout=10)
70
- resp.raise_for_status()
 
 
71
 
72
- # keep the original extension so torchaudio picks the right decoder
73
- #ext = os.path.splitext(urllib.parse.urlparse(url).path)[1] or ".tmp"
74
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
 
 
 
 
 
 
 
 
 
 
 
75
  f.write(resp.content)
76
- tmp_path = f.name
77
 
 
78
  try:
79
- wav, sr = torchaudio.load(tmp_path) # handles wav / mp3 / flac
80
- if wav.ndim > 1:
81
- wav = wav.mean(dim=0, keepdim=True)
82
- if sr != target_sr:
83
- wav = torchaudio.functional.resample(wav, sr, target_sr)
84
- return wav # [1,β€―T] float32 in [-1,1]
 
 
 
 
 
 
 
 
 
85
  finally:
86
- try:
87
- os.remove(tmp_path)
88
- except Exception as e:
89
- logging.warning("Could not delete temp file %s: %s", tmp_path, e)
90
 
 
 
 
 
 
 
 
 
 
91
  # ────────────────────────────────────────────────────────────────────────────────
92
  # gRPC service
93
  # ────────────────────────────────────────────────────────────────────────────────
@@ -182,55 +213,50 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
182
  return
183
 
184
 
185
- # 4. Instruction‑TTS (two flavours)
186
  if request.HasField("instruct_request"):
 
187
  ir = request.instruct_request
188
 
189
- # ──────────────────────────────────────────────────────────────────
190
- # 4‑a) instruct‑2 (has prompt_audio β†’ bytes OR S3 URL)
191
- # ──────────────────────────────────────────────────────────────────
192
- if ir.HasField("prompt_audio"):
193
- logging.info("Received instruct‑2 inference request")
 
194
 
195
- tmp_path = None
196
- try:
197
- if ir.prompt_audio.startswith(b'http'):
198
- prompt = _load_prompt_from_url(ir.prompt_audio.decode('utf‑8'))
199
- else:
200
- # legacy raw‑bytes payload
201
- prompt = _bytes_to_tensor(ir.prompt_audio)
202
 
203
- speed = getattr(ir, "speed", 1.0)
204
- mo = self.cosyvoice.inference_instruct2(
205
- ir.tts_text,
206
- ir.instruct_text,
207
- prompt,
208
- stream=False,
209
- speed=speed
210
- )
211
 
212
- finally:
213
- if tmp_path and os.path.exists(tmp_path):
214
- try:
215
- os.remove(tmp_path)
216
- except Exception as e:
217
- logging.warning("Could not remove temp file %s: %s",
218
- tmp_path, e)
219
 
220
- # ──────────────────────────────────────────────────────────────────
221
- # 4‑b) classic instruct (speaker‑ID, no prompt audio)
222
- # ──────────────────────────────────────────────────────────────────
223
  else:
224
- logging.info("Received instruct inference request")
225
- mo = self.cosyvoice.inference_instruct(
226
- ir.tts_text,
227
- ir.spk_id,
228
- ir.instruct_text
229
- )
 
 
 
 
230
 
231
  yield from _yield_audio(mo)
232
  return
233
 
 
234
  # unknown request type
235
  context.abort(grpc.StatusCode.INVALID_ARGUMENT,
236
  "Unsupported request type in oneof field.")
 
59
  resp = cosyvoice_pb2.Response(tts_audio=pcm16.tobytes())
60
  yield resp
61
 
62
+ import os, io, tempfile, requests, torch, torchaudio
63
+ from urllib.parse import urlparse
64
 
65
  def _load_prompt_from_url(url: str, target_sr: int = 16_000) -> torch.Tensor:
66
+ """Download an audio file from ``url`` (wav / mp3 / flac / ogg …),
67
+ convert it to mono, resample to ``target_sr`` if necessary,
68
+ and return a 1Γ—T float‑tensor in the range ‑1…1."""
69
+
70
+ # ─── 1. Download ────────────────────────────────────────────────────────────
71
  resp = requests.get(url, timeout=10)
72
+ if resp.status_code != 200:
73
+ raise HTTPException(status_code=400,
74
+ detail=f"Failed to download audio from URL: {url}")
75
 
76
+ # Infer extension from URL *or* Content‑Type header
77
+ ext = os.path.splitext(urlparse(url).path)[1].lower()
78
+ if not ext and 'content-type' in resp.headers:
79
+ mime = resp.headers['content-type'].split(';')[0].strip()
80
+ ext = {
81
+ 'audio/mpeg': '.mp3',
82
+ 'audio/wav': '.wav',
83
+ 'audio/x-wav': '.wav',
84
+ 'audio/flac': '.flac',
85
+ 'audio/ogg': '.ogg',
86
+ 'audio/x-m4a': '.m4a',
87
+ }.get(mime, '.audio') # generic fallback
88
+
89
+ with tempfile.NamedTemporaryFile(suffix=ext or '.audio', delete=False) as f:
90
  f.write(resp.content)
91
+ temp_path = f.name
92
 
93
+ # ─── 2. Decode (torchaudio first, pydub fallback) ──────────────────────────
94
  try:
95
+ # Let torchaudio pick the right backend automatically
96
+ speech, sample_rate = torchaudio.load(temp_path)
97
+ except Exception:
98
+ # Fallback that works as long as ffmpeg is present
99
+ from pydub import AudioSegment
100
+ import numpy as np
101
+
102
+ seg = AudioSegment.from_file(temp_path) # any ffmpeg‑supported format
103
+ seg = seg.set_channels(1) # force mono
104
+ sample_rate = seg.frame_rate
105
+ np_audio = np.array(seg.get_array_of_samples()).astype(np.float32)
106
+ # normalise to βˆ’1…1 based on sample width
107
+ np_audio /= float(1 << (8 * seg.sample_width - 1))
108
+ speech = torch.from_numpy(np_audio).unsqueeze(0)
109
+
110
  finally:
111
+ os.unlink(temp_path)
 
 
 
112
 
113
+ # ─── 3. Ensure mono + correct sample‑rate ──────────────────────────────────
114
+ if speech.dim() > 1 and speech.size(0) > 1:
115
+ speech = speech.mean(dim=0, keepdim=True) # average to mono
116
+
117
+ if sample_rate != target_sr:
118
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate,
119
+ new_freq=target_sr)(speech)
120
+ return speech
121
+
122
  # ────────────────────────────────────────────────────────────────────────────────
123
  # gRPC service
124
  # ────────────────────────────────────────────────────────────────────────────────
 
213
  return
214
 
215
 
216
+ # 4. Instruct‑2 (CosyVoice2 supports this variant only)
217
  if request.HasField("instruct_request"):
218
+
219
  ir = request.instruct_request
220
 
221
+ # ---- require that the descriptor contains the field -------------------
222
+ if 'prompt_audio' not in ir.DESCRIPTOR.fields_by_name:
223
+ context.abort(
224
+ grpc.StatusCode.INVALID_ARGUMENT,
225
+ "Server expects instruct‑2 proto with a 'prompt_audio' field."
226
+ )
227
 
228
+ # ---- make sure it is non‑empty (no HasField for proto3 scalars) -------
229
+ if len(ir.prompt_audio) == 0:
230
+ context.abort(
231
+ grpc.StatusCode.INVALID_ARGUMENT,
232
+ "'prompt_audio' must not be empty for instruct‑2 requests."
233
+ )
 
234
 
235
+ logging.info("Received instruct‑2 inference request")
 
 
 
 
 
 
 
236
 
237
+ # convert to bytes no matter what scalar type the proto uses
238
+ pa_bytes = (ir.prompt_audio.encode('utf-8') if isinstance(ir.prompt_audio, str)
239
+ else ir.prompt_audio)
 
 
 
 
240
 
241
+ # URL vs raw bytes
242
+ if pa_bytes.startswith(b"http"):
243
+ prompt = _load_prompt_from_url(pa_bytes.decode('utf-8'))
244
  else:
245
+ prompt = _bytes_to_tensor(pa_bytes)
246
+
247
+ speed = getattr(ir, "speed", 1.0)
248
+ mo = self.cosyvoice.inference_instruct2(
249
+ ir.tts_text,
250
+ ir.instruct_text,
251
+ prompt,
252
+ stream=False,
253
+ speed=speed,
254
+ )
255
 
256
  yield from _yield_audio(mo)
257
  return
258
 
259
+
260
  # unknown request type
261
  context.abort(grpc.StatusCode.INVALID_ARGUMENT,
262
  "Unsupported request type in oneof field.")