Afrinetwork7 commited on
Commit
0ec2266
·
verified ·
1 Parent(s): 4813dbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -12
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException
2
  from fastapi.responses import JSONResponse, FileResponse
3
  from pydantic import BaseModel
4
  import numpy as np
@@ -10,12 +10,13 @@ import torch
10
  import librosa
11
  from transformers import Wav2Vec2ForCTC, AutoProcessor
12
  from pathlib import Path
 
 
13
 
14
  # Import functions from other modules
15
  from asr import transcribe, ASR_LANGUAGES
16
  from tts import synthesize, TTS_LANGUAGES
17
  from lid import identify
18
-
19
  from asr import ASR_SAMPLING_RATE, transcribe
20
 
21
  # Configure logging
@@ -26,7 +27,7 @@ app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")
26
 
27
  # Define request models
28
  class AudioRequest(BaseModel):
29
- audio: str # Base64 encoded audio data
30
  language: str
31
 
32
  class TTSRequest(BaseModel):
@@ -34,19 +35,42 @@ class TTSRequest(BaseModel):
34
  language: str
35
  speed: float
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @app.post("/transcribe")
38
  async def transcribe_audio(request: AudioRequest):
39
  try:
40
- audio_bytes = base64.b64decode(request.audio)
41
- audio_array, sample_rate = sf.read(io.BytesIO(audio_bytes))
 
42
  # Convert to mono if stereo
43
  if len(audio_array.shape) > 1:
44
  audio_array = audio_array.mean(axis=1)
 
45
  # Ensure audio_array is float32
46
  audio_array = audio_array.astype(np.float32)
 
47
  # Resample if necessary
48
  if sample_rate != ASR_SAMPLING_RATE:
49
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
 
50
  result = transcribe(audio_array, request.language)
51
  return JSONResponse(content={"transcription": result})
52
  except Exception as e:
@@ -57,15 +81,13 @@ async def transcribe_audio(request: AudioRequest):
57
  async def synthesize_speech(request: TTSRequest):
58
  try:
59
  audio, filtered_text = synthesize(request.text, request.language, request.speed)
60
-
61
  # Convert numpy array to bytes
62
  buffer = io.BytesIO()
63
  sf.write(buffer, audio, 22050, format='wav')
64
  buffer.seek(0)
65
-
66
  return FileResponse(
67
- buffer,
68
- media_type="audio/wav",
69
  headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"}
70
  )
71
  except Exception as e:
@@ -75,9 +97,8 @@ async def synthesize_speech(request: TTSRequest):
75
  @app.post("/identify")
76
  async def identify_language(request: AudioRequest):
77
  try:
78
- audio_bytes = base64.b64decode(request.audio)
79
- audio_array, sample_rate = sf.read(io.BytesIO(audio_bytes))
80
-
81
  result = identify(audio_array)
82
  return JSONResponse(content={"language_identification": result})
83
  except Exception as e:
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File
2
  from fastapi.responses import JSONResponse, FileResponse
3
  from pydantic import BaseModel
4
  import numpy as np
 
10
  import librosa
11
  from transformers import Wav2Vec2ForCTC, AutoProcessor
12
  from pathlib import Path
13
+ from moviepy.editor import VideoFileClip
14
+ import magic # For MIME type detection
15
 
16
  # Import functions from other modules
17
  from asr import transcribe, ASR_LANGUAGES
18
  from tts import synthesize, TTS_LANGUAGES
19
  from lid import identify
 
20
  from asr import ASR_SAMPLING_RATE, transcribe
21
 
22
  # Configure logging
 
27
 
28
  # Define request models
29
  class AudioRequest(BaseModel):
30
+ audio: str # Base64 encoded audio or video data
31
  language: str
32
 
33
  class TTSRequest(BaseModel):
 
35
  language: str
36
  speed: float
37
 
38
+ def detect_mime_type(input_bytes):
39
+ mime = magic.Magic(mime=True)
40
+ return mime.from_buffer(input_bytes)
41
+
42
+ def extract_audio(input_bytes):
43
+ mime_type = detect_mime_type(input_bytes)
44
+
45
+ if mime_type.startswith('audio/'):
46
+ return sf.read(io.BytesIO(input_bytes))
47
+ elif mime_type.startswith('video/'):
48
+ with io.BytesIO(input_bytes) as f:
49
+ video = VideoFileClip(f.name)
50
+ audio = video.audio
51
+ audio_array = audio.to_soundarray()
52
+ sample_rate = audio.fps
53
+ return audio_array, sample_rate
54
+ else:
55
+ raise ValueError(f"Unsupported MIME type: {mime_type}")
56
+
57
  @app.post("/transcribe")
58
  async def transcribe_audio(request: AudioRequest):
59
  try:
60
+ input_bytes = base64.b64decode(request.audio)
61
+ audio_array, sample_rate = extract_audio(input_bytes)
62
+
63
  # Convert to mono if stereo
64
  if len(audio_array.shape) > 1:
65
  audio_array = audio_array.mean(axis=1)
66
+
67
  # Ensure audio_array is float32
68
  audio_array = audio_array.astype(np.float32)
69
+
70
  # Resample if necessary
71
  if sample_rate != ASR_SAMPLING_RATE:
72
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
73
+
74
  result = transcribe(audio_array, request.language)
75
  return JSONResponse(content={"transcription": result})
76
  except Exception as e:
 
81
  async def synthesize_speech(request: TTSRequest):
82
  try:
83
  audio, filtered_text = synthesize(request.text, request.language, request.speed)
 
84
  # Convert numpy array to bytes
85
  buffer = io.BytesIO()
86
  sf.write(buffer, audio, 22050, format='wav')
87
  buffer.seek(0)
 
88
  return FileResponse(
89
+ buffer,
90
+ media_type="audio/wav",
91
  headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"}
92
  )
93
  except Exception as e:
 
97
  @app.post("/identify")
98
  async def identify_language(request: AudioRequest):
99
  try:
100
+ input_bytes = base64.b64decode(request.audio)
101
+ audio_array, sample_rate = extract_audio(input_bytes)
 
102
  result = identify(audio_array)
103
  return JSONResponse(content={"language_identification": result})
104
  except Exception as e: