david commited on
Commit
9bdac3d
·
1 Parent(s): 90a29dd

add fastapi to host websocket

Browse files
frontend/index.html CHANGED
@@ -10,6 +10,52 @@
10
  </head>
11
  <body>
12
  <div id="app"></div>
13
-
14
  </body>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  </html>
 
10
  </head>
11
  <body>
12
  <div id="app"></div>
 
13
  </body>
14
+ <script>
15
+ async function startRecording() {
16
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
17
+ const audioContext = new AudioContext({ sampleRate: 16000 });
18
+ const source = audioContext.createMediaStreamSource(stream);
19
+ const processor = audioContext.createScriptProcessor(4096, 1, 1);
20
+
21
+ const wsUrl = "ws://localhost:9090/ws?from=zh&to=en";
22
+ ws = new WebSocket(wsUrl);
23
+
24
+ ws.binaryType = "arraybuffer";
25
+
26
+ ws.onopen = () => {
27
+ console.log("WebSocket opened");
28
+ source.connect(processor);
29
+ processor.connect(audioContext.destination);
30
+
31
+ processor.onaudioprocess = (e) => {
32
+ const input = e.inputBuffer.getChannelData(0);
33
+ const buffer = new Int16Array(input.length);
34
+ for (let i = 0; i < input.length; i++) {
35
+ buffer[i] = Math.max(-1, Math.min(1, input[i])) * 0x7FFF;
36
+ }
37
+ ws.send(buffer);
38
+ };
39
+ };
40
+
41
+ ws.onmessage = (event) => {
42
+ try {
43
+ const msg = JSON.parse(event.data);
44
+ if (msg.result) {
45
+ addTranslation(msg.result);
46
+ }
47
+ } catch (e) {
48
+ console.error("Parse error:", e);
49
+ }
50
+ };
51
+
52
+ ws.onerror = (e) => console.error("WebSocket error:", e);
53
+ ws.onclose = () => {
54
+ console.log("WebSocket closed");
55
+ processor.disconnect();
56
+ source.disconnect();
57
+ };
58
+ }
59
+ </script>
60
+ </body>
61
  </html>
main.py CHANGED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket
2
+ from urllib.parse import urlparse, parse_qsl
3
+ from transcribe.whisper_llm_serve import PyWhiperCppServe
4
+ from uuid import uuid1
5
+ from logging import getLogger
6
+ import numpy as np
7
+ from transcribe.translatepipes import TranslatePipes
8
+ from contextlib import asynccontextmanager
9
+ from multiprocessing import Process, freeze_support
10
+ from fastapi.staticfiles import StaticFiles
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
+ async def get_audio_from_websocket(websocket)->np.array:
16
+ """
17
+ Receives audio buffer from websocket and creates a numpy array out of it.
18
+
19
+ Args:
20
+ websocket: The websocket to receive audio from.
21
+
22
+ Returns:
23
+ A numpy array containing the audio.
24
+ """
25
+ frame_data = await websocket.receive_bytes()
26
+ if frame_data == b"END_OF_AUDIO":
27
+ return False
28
+ return np.frombuffer(frame_data, dtype=np.int16).astype(np.float32) / 32768.0
29
+
30
+
31
+ @asynccontextmanager
32
+ async def lifespan(app:FastAPI):
33
+ global pipe
34
+ pipe = TranslatePipes()
35
+ pipe.wait_ready()
36
+ logger.info("Pipeline is ready.")
37
+ yield
38
+
39
+
40
+
41
+ app = FastAPI(lifespan=lifespan)
42
+ app.mount("/translate", StaticFiles(directory="frontend"),)
43
+ pipe = None
44
+
45
+ @app.websocket("/ws")
46
+ async def translate(websocket: WebSocket):
47
+ query_parameters_dict = websocket.query_params
48
+ from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
49
+ client = PyWhiperCppServe(
50
+ websocket,
51
+ pipe,
52
+ language="en",
53
+ client_uid=f"{uuid1()}",
54
+ )
55
+
56
+ if from_lang and to_lang:
57
+ client.set_lang(from_lang, to_lang)
58
+ logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
59
+ await websocket.accept()
60
+ while True:
61
+ frame_data = await get_audio_from_websocket(websocket)
62
+ client.add_frames(frame_data)
63
+
64
+
65
+
66
+ if __name__ == '__main__':
67
+
68
+ freeze_support()
69
+ import uvicorn
70
+ uvicorn.run(app, host="0.0.0.0", port=9090)
pyproject.toml CHANGED
@@ -18,6 +18,7 @@ dependencies = [
18
  "soundfile>=0.13.1",
19
  "torch>=2.6.0",
20
  "tqdm>=4.67.1",
 
21
  "websocket-client>=1.8.0",
22
  "websockets>=15.0.1",
23
  ]
 
18
  "soundfile>=0.13.1",
19
  "torch>=2.6.0",
20
  "tqdm>=4.67.1",
21
+ "uvicorn>=0.34.0",
22
  "websocket-client>=1.8.0",
23
  "websockets>=15.0.1",
24
  ]
transcribe/pipelines/pipe_vad.py CHANGED
@@ -4,7 +4,7 @@ from ..helpers.vadprocessor import SileroVADProcessor, FixedVADIterator
4
  import numpy as np
5
  from silero_vad import get_speech_timestamps,collect_chunks
6
  import torch
7
- import noisereduce as nr
8
 
9
 
10
  class VadPipe(BasePipe):
@@ -34,8 +34,8 @@ class VadPipe(BasePipe):
34
  return np.array([], dtype=np.float32)
35
 
36
 
37
- def reduce_noise(self, data):
38
- return nr.reduce_noise(y=data, sr=self.sample_rate)
39
 
40
 
41
  def process(self, in_data: MetaItem) -> MetaItem:
 
4
  import numpy as np
5
  from silero_vad import get_speech_timestamps,collect_chunks
6
  import torch
7
+ # import noisereduce as nr
8
 
9
 
10
  class VadPipe(BasePipe):
 
34
  return np.array([], dtype=np.float32)
35
 
36
 
37
+ # def reduce_noise(self, data):
38
+ # return nr.reduce_noise(y=data, sr=self.sample_rate)
39
 
40
 
41
  def process(self, in_data: MetaItem) -> MetaItem:
transcribe/whisper_llm_serve.py CHANGED
@@ -1,10 +1,8 @@
1
 
2
 
3
- import soundfile
4
- import multiprocessing as mp
5
  import numpy as np
6
  from logging import getLogger
7
-
8
  from .utils import save_to_wave
9
  import time
10
  import json
@@ -19,16 +17,11 @@ from .strategy import TripleTextBuffer, SegmentManager, segments_split, sequence
19
 
20
  logger = getLogger("TranslatorApp")
21
 
22
- translate_pipes = TranslatePipes()
23
- translate_pipes.wait_ready()
24
- logger.info("Pipeline is ready.")
25
-
26
-
27
 
28
 
29
  class PyWhiperCppServe(ServeClientBase):
30
 
31
- def __init__(self, websocket, language=None, dst_lang=None, client_uid=None,):
32
  super().__init__(client_uid, websocket)
33
  self.language = language
34
  self.dst_lang = dst_lang # 目标翻译语言
@@ -36,7 +29,7 @@ class PyWhiperCppServe(ServeClientBase):
36
  self._text_buffer = TripleTextBuffer()
37
  # 存储转录数据
38
  self._segment_manager = SegmentManager()
39
-
40
  self.lock = threading.Lock()
41
  self.frames_np = None
42
  self._frame_queue = queue.Queue()
@@ -71,7 +64,7 @@ class PyWhiperCppServe(ServeClientBase):
71
  def vad_merge(self):
72
  with self.lock:
73
  frame = self.frames_np.copy()
74
- item = translate_pipes.voice_detect(frame.tobytes())
75
  if item.audio != b'':
76
  frame_np = np.frombuffer(item.audio, dtype=np.float32)
77
  self.frames_np = frame_np.copy()
@@ -105,7 +98,7 @@ class PyWhiperCppServe(ServeClientBase):
105
  log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
106
  start_time = time.perf_counter()
107
 
108
- item = translate_pipes.transcrible(audio_buffer.tobytes(), self.language)
109
  segments = item.segments
110
  log_block("Whisper transcrible time", f"{(time.perf_counter() - start_time):.3f}", "s")
111
 
@@ -117,7 +110,7 @@ class PyWhiperCppServe(ServeClientBase):
117
  # return "sample english"
118
  log_block("LLM translate input", f"{text}")
119
  start_time = time.perf_counter()
120
- ret = translate_pipes.translate(text, self.language, self.dst_lang)
121
  translated_text = ret.translate_content
122
  log_block("LLM translate time", f"{(time.perf_counter() - start_time):.3f}", "s")
123
  log_block("LLM translate out", f"{translated_text}")
@@ -227,12 +220,12 @@ class PyWhiperCppServe(ServeClientBase):
227
  )
228
 
229
  def send_to_client(self, data:TransResult):
230
- try:
231
- self.websocket.send(
232
- Message(result=data, request_id=self.client_uid).model_dump_json(by_alias=True)
233
- )
234
- except Exception as e:
235
- logger.error(f"Sending data to client: {e}")
236
 
237
  def get_audio_chunk_for_processing(self):
238
  if self.frames_np.shape[0] >= self.sample_rate * 1:
 
1
 
2
 
 
 
3
  import numpy as np
4
  from logging import getLogger
5
+ import asyncio
6
  from .utils import save_to_wave
7
  import time
8
  import json
 
17
 
18
  logger = getLogger("TranslatorApp")
19
 
 
 
 
 
 
20
 
21
 
22
  class PyWhiperCppServe(ServeClientBase):
23
 
24
+ def __init__(self, websocket, pipe:TranslatePipes,language=None, dst_lang=None, client_uid=None,):
25
  super().__init__(client_uid, websocket)
26
  self.language = language
27
  self.dst_lang = dst_lang # 目标翻译语言
 
29
  self._text_buffer = TripleTextBuffer()
30
  # 存储转录数据
31
  self._segment_manager = SegmentManager()
32
+ self._translate_pipes = pipe
33
  self.lock = threading.Lock()
34
  self.frames_np = None
35
  self._frame_queue = queue.Queue()
 
64
  def vad_merge(self):
65
  with self.lock:
66
  frame = self.frames_np.copy()
67
+ item = self._translate_pipes.voice_detect(frame.tobytes())
68
  if item.audio != b'':
69
  frame_np = np.frombuffer(item.audio, dtype=np.float32)
70
  self.frames_np = frame_np.copy()
 
98
  log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
99
  start_time = time.perf_counter()
100
 
101
+ item = self._translate_pipes.transcrible(audio_buffer.tobytes(), self.language)
102
  segments = item.segments
103
  log_block("Whisper transcrible time", f"{(time.perf_counter() - start_time):.3f}", "s")
104
 
 
110
  # return "sample english"
111
  log_block("LLM translate input", f"{text}")
112
  start_time = time.perf_counter()
113
+ ret = self._translate_pipes.translate(text, self.language, self.dst_lang)
114
  translated_text = ret.translate_content
115
  log_block("LLM translate time", f"{(time.perf_counter() - start_time):.3f}", "s")
116
  log_block("LLM translate out", f"{translated_text}")
 
220
  )
221
 
222
  def send_to_client(self, data:TransResult):
223
+
224
+ coro = self.websocket.send_text(
225
+ Message(result=data, request_id=self.client_uid).model_dump_json(by_alias=True)
226
+ )
227
+ asyncio.run(coro)
228
+
229
 
230
  def get_audio_chunk_for_processing(self):
231
  if self.frames_np.shape[0] >= self.sample_rate * 1:
uv.lock CHANGED
@@ -207,6 +207,18 @@ wheels = [
207
  { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
208
  ]
209
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  [[package]]
211
  name = "colorama"
212
  version = "0.4.6"
@@ -278,6 +290,15 @@ wheels = [
278
  { url = "https://files.pythonhosted.org/packages/5c/0d/24d40adaacf77f133ac87a29c045ee0d81bb99732b09b6ff0251c76e5c67/fsspec-2025.3.1-py3-none-any.whl", hash = "sha256:2ce85886f37dfa12d5ad4764f1342efbf00ec0a4fe164f070038499d80142887", size = 194444 },
279
  ]
280
 
 
 
 
 
 
 
 
 
 
281
  [[package]]
282
  name = "humanfriendly"
283
  version = "10.0"
@@ -1205,6 +1226,7 @@ dependencies = [
1205
  { name = "soundfile" },
1206
  { name = "torch" },
1207
  { name = "tqdm" },
 
1208
  { name = "websocket-client" },
1209
  { name = "websockets" },
1210
  ]
@@ -1224,6 +1246,7 @@ requires-dist = [
1224
  { name = "soundfile", specifier = ">=0.13.1" },
1225
  { name = "torch", specifier = ">=2.6.0" },
1226
  { name = "tqdm", specifier = ">=4.67.1" },
 
1227
  { name = "websocket-client", specifier = ">=1.8.0" },
1228
  { name = "websockets", specifier = ">=15.0.1" },
1229
  ]
@@ -1268,6 +1291,19 @@ wheels = [
1268
  { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
1269
  ]
1270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1271
  [[package]]
1272
  name = "websocket-client"
1273
  version = "1.8.0"
 
207
  { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
208
  ]
209
 
210
+ [[package]]
211
+ name = "click"
212
+ version = "8.1.8"
213
+ source = { registry = "https://pypi.org/simple" }
214
+ dependencies = [
215
+ { name = "colorama", marker = "sys_platform == 'win32'" },
216
+ ]
217
+ sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
218
+ wheels = [
219
+ { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 },
220
+ ]
221
+
222
  [[package]]
223
  name = "colorama"
224
  version = "0.4.6"
 
290
  { url = "https://files.pythonhosted.org/packages/5c/0d/24d40adaacf77f133ac87a29c045ee0d81bb99732b09b6ff0251c76e5c67/fsspec-2025.3.1-py3-none-any.whl", hash = "sha256:2ce85886f37dfa12d5ad4764f1342efbf00ec0a4fe164f070038499d80142887", size = 194444 },
291
  ]
292
 
293
+ [[package]]
294
+ name = "h11"
295
+ version = "0.14.0"
296
+ source = { registry = "https://pypi.org/simple" }
297
+ sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 }
298
+ wheels = [
299
+ { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 },
300
+ ]
301
+
302
  [[package]]
303
  name = "humanfriendly"
304
  version = "10.0"
 
1226
  { name = "soundfile" },
1227
  { name = "torch" },
1228
  { name = "tqdm" },
1229
+ { name = "uvicorn" },
1230
  { name = "websocket-client" },
1231
  { name = "websockets" },
1232
  ]
 
1246
  { name = "soundfile", specifier = ">=0.13.1" },
1247
  { name = "torch", specifier = ">=2.6.0" },
1248
  { name = "tqdm", specifier = ">=4.67.1" },
1249
+ { name = "uvicorn", specifier = ">=0.34.0" },
1250
  { name = "websocket-client", specifier = ">=1.8.0" },
1251
  { name = "websockets", specifier = ">=15.0.1" },
1252
  ]
 
1291
  { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
1292
  ]
1293
 
1294
+ [[package]]
1295
+ name = "uvicorn"
1296
+ version = "0.34.0"
1297
+ source = { registry = "https://pypi.org/simple" }
1298
+ dependencies = [
1299
+ { name = "click" },
1300
+ { name = "h11" },
1301
+ ]
1302
+ sdist = { url = "https://files.pythonhosted.org/packages/4b/4d/938bd85e5bf2edeec766267a5015ad969730bb91e31b44021dfe8b22df6c/uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9", size = 76568 }
1303
+ wheels = [
1304
+ { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 },
1305
+ ]
1306
+
1307
  [[package]]
1308
  name = "websocket-client"
1309
  version = "1.8.0"