Spaces:
Edmond98
/
Running on TPU v5e

Afrinetwork7 commited on
Commit
2300584
1 Parent(s): fbec879

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -115
app.py CHANGED
@@ -1,19 +1,16 @@
1
  import logging
2
  import math
3
- import os
4
- import tempfile
5
  import time
 
 
6
  from typing import Dict, Any
7
  from functools import wraps
8
 
9
- import yt_dlp as youtube_dl
10
- from fastapi import FastAPI, File, UploadFile, Depends, HTTPException
11
- from fastapi.responses import HTMLResponse
12
  from fastapi.encoders import jsonable_encoder
13
  from pydantic import BaseModel
14
  import jax.numpy as jnp
15
  import numpy as np
16
- from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
17
  from transformers.pipelines.audio_utils import ffmpeg_read
18
  from whisper_jax import FlaxWhisperPipline
19
 
@@ -33,7 +30,6 @@ BATCH_SIZE = 32
33
  CHUNK_LENGTH_S = 30
34
  NUM_PROC = 32
35
  FILE_LIMIT_MB = 10000
36
- YT_LENGTH_LIMIT_S = 15000 # limit to 2 hour YouTube files
37
 
38
  pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
39
  stride_length_s = CHUNK_LENGTH_S / 6
@@ -54,11 +50,7 @@ compile_time = time.time() - start
54
  logger.debug(f"Compiled in {compile_time}s")
55
 
56
  class TranscribeAudioRequest(BaseModel):
57
- task: str = "transcribe"
58
- return_timestamps: bool = False
59
-
60
- class TranscribeYouTubeRequest(BaseModel):
61
- yt_url: str
62
  task: str = "transcribe"
63
  return_timestamps: bool = False
64
 
@@ -79,41 +71,33 @@ def timeit(func):
79
  @app.post("/transcribe_audio")
80
  @timeit
81
  async def transcribe_chunked_audio(
82
- audio_file: UploadFile = File(...),
83
- request: TranscribeAudioRequest = Depends()
84
  ) -> Dict[str, Any]:
85
  logger.debug("Starting transcribe_chunked_audio function")
86
  logger.debug(f"Received parameters - task: {request.task}, return_timestamps: {request.return_timestamps}")
87
 
88
- logger.debug("Checking for audio file...")
89
- if not audio_file:
90
- logger.warning("No audio file")
91
- raise HTTPException(status_code=400, detail="No audio file submitted!")
92
-
93
- logger.debug(f"Audio file received: {audio_file.filename}")
94
-
95
  try:
96
- # Read the file content
97
- file_content = await audio_file.read()
98
- file_size = len(file_content)
99
  file_size_mb = file_size / (1024 * 1024)
100
- logger.debug(f"File size: {file_size} bytes ({file_size_mb:.2f}MB)")
101
  except Exception as e:
102
- logger.error(f"Error reading file: {str(e)}", exc_info=True)
103
- raise HTTPException(status_code=500, detail=f"Error reading file: {str(e)}")
104
 
105
  if file_size_mb > FILE_LIMIT_MB:
106
  logger.warning(f"Max file size exceeded: {file_size_mb:.2f}MB > {FILE_LIMIT_MB}MB")
107
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
108
 
109
  try:
110
- logger.debug("Performing ffmpeg read on audio file")
111
- inputs = ffmpeg_read(file_content, pipeline.feature_extractor.sampling_rate)
112
  inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
113
  logger.debug("ffmpeg read completed successfully")
114
  except Exception as e:
115
  logger.error(f"Error in ffmpeg read: {str(e)}", exc_info=True)
116
- raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
117
 
118
  logger.debug("Calling tqdm_generate to transcribe audio")
119
  try:
@@ -130,51 +114,6 @@ async def transcribe_chunked_audio(
130
  "timing_info": timing_info
131
  })
132
 
133
- @app.post("/transcribe_youtube")
134
- @timeit
135
- async def transcribe_youtube(request: TranscribeYouTubeRequest) -> Dict[str, Any]:
136
- logger.debug("Loading YouTube file...")
137
- try:
138
- html_embed_str = _return_yt_html_embed(request.yt_url)
139
- except Exception as e:
140
- logger.error("Error generating YouTube HTML embed:", exc_info=True)
141
- raise HTTPException(status_code=500, detail="Error generating YouTube HTML embed")
142
-
143
- with tempfile.TemporaryDirectory() as tmpdirname:
144
- filepath = os.path.join(tmpdirname, "video.mp4")
145
- try:
146
- logger.debug("Downloading YouTube audio...")
147
- download_yt_audio(request.yt_url, filepath)
148
- except Exception as e:
149
- logger.error("Error downloading YouTube audio:", exc_info=True)
150
- raise HTTPException(status_code=500, detail="Error downloading YouTube audio")
151
-
152
- try:
153
- logger.debug(f"Opening downloaded audio file: {filepath}")
154
- with open(filepath, "rb") as f:
155
- inputs = f.read()
156
- except Exception as e:
157
- logger.error("Error reading downloaded audio file:", exc_info=True)
158
- raise HTTPException(status_code=500, detail="Error reading downloaded audio file")
159
-
160
- inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
161
- inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
162
- logger.debug("Done loading YouTube file")
163
-
164
- try:
165
- logger.debug("Calling tqdm_generate to transcribe YouTube audio")
166
- text, runtime, timing_info = tqdm_generate(inputs, task=request.task, return_timestamps=request.return_timestamps)
167
- except Exception as e:
168
- logger.error("Error transcribing YouTube audio:", exc_info=True)
169
- raise HTTPException(status_code=500, detail="Error transcribing YouTube audio")
170
-
171
- return jsonable_encoder({
172
- "html_embed": html_embed_str,
173
- "text": text,
174
- "runtime": runtime,
175
- "timing_info": timing_info
176
- })
177
-
178
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
179
  start_time = time.time()
180
  logger.debug(f"Starting tqdm_generate - task: {task}, return_timestamps: {return_timestamps}")
@@ -236,46 +175,6 @@ def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
236
  "total_processing_time": total_processing_time
237
  }
238
 
239
- def _return_yt_html_embed(yt_url):
240
- video_id = yt_url.split("?v=")[-1]
241
- HTML_str = (
242
- f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
243
- " </center>"
244
- )
245
- return HTML_str
246
-
247
- def download_yt_audio(yt_url, filename):
248
- info_loader = youtube_dl.YoutubeDL()
249
- try:
250
- logger.debug(f"Extracting info for YouTube URL: {yt_url}")
251
- info = info_loader.extract_info(yt_url, download=False)
252
- except youtube_dl.utils.DownloadError as err:
253
- logger.error("Error extracting YouTube info:", exc_info=True)
254
- raise HTTPException(status_code=400, detail=str(err))
255
-
256
- file_length = info["duration_string"]
257
- file_h_m_s = file_length.split(":")
258
- file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
259
- if len(file_h_m_s) == 1:
260
- file_h_m_s.insert(0, 0)
261
- if len(file_h_m_s) == 2:
262
- file_h_m_s.insert(0, 0)
263
-
264
- file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
265
- if file_length_s > YT_LENGTH_LIMIT_S:
266
- yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
267
- file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
268
- raise HTTPException(status_code=400, detail=f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
269
-
270
- ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
271
- with youtube_dl.YoutubeDL(ydl_opts) as ydl:
272
- try:
273
- logger.debug(f"Downloading YouTube audio to {filename}")
274
- ydl.download([yt_url])
275
- except youtube_dl.utils.ExtractorError as err:
276
- logger.error("Error downloading YouTube audio:", exc_info=True)
277
- raise HTTPException(status_code=400, detail=str(err))
278
-
279
  def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
280
  if seconds is not None:
281
  milliseconds = round(seconds * 1000.0)
 
1
  import logging
2
  import math
 
 
3
  import time
4
+ import base64
5
+ import io
6
  from typing import Dict, Any
7
  from functools import wraps
8
 
9
+ from fastapi import FastAPI, Depends, HTTPException
 
 
10
  from fastapi.encoders import jsonable_encoder
11
  from pydantic import BaseModel
12
  import jax.numpy as jnp
13
  import numpy as np
 
14
  from transformers.pipelines.audio_utils import ffmpeg_read
15
  from whisper_jax import FlaxWhisperPipline
16
 
 
30
  CHUNK_LENGTH_S = 30
31
  NUM_PROC = 32
32
  FILE_LIMIT_MB = 10000
 
33
 
34
  pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
35
  stride_length_s = CHUNK_LENGTH_S / 6
 
50
  logger.debug(f"Compiled in {compile_time}s")
51
 
52
  class TranscribeAudioRequest(BaseModel):
53
+ audio_base64: str
 
 
 
 
54
  task: str = "transcribe"
55
  return_timestamps: bool = False
56
 
 
71
  @app.post("/transcribe_audio")
72
  @timeit
73
  async def transcribe_chunked_audio(
74
+ request: TranscribeAudioRequest
 
75
  ) -> Dict[str, Any]:
76
  logger.debug("Starting transcribe_chunked_audio function")
77
  logger.debug(f"Received parameters - task: {request.task}, return_timestamps: {request.return_timestamps}")
78
 
 
 
 
 
 
 
 
79
  try:
80
+ # Decode base64 audio data
81
+ audio_data = base64.b64decode(request.audio_base64)
82
+ file_size = len(audio_data)
83
  file_size_mb = file_size / (1024 * 1024)
84
+ logger.debug(f"Decoded audio data size: {file_size} bytes ({file_size_mb:.2f}MB)")
85
  except Exception as e:
86
+ logger.error(f"Error decoding base64 audio data: {str(e)}", exc_info=True)
87
+ raise HTTPException(status_code=400, detail=f"Error decoding base64 audio data: {str(e)}")
88
 
89
  if file_size_mb > FILE_LIMIT_MB:
90
  logger.warning(f"Max file size exceeded: {file_size_mb:.2f}MB > {FILE_LIMIT_MB}MB")
91
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
92
 
93
  try:
94
+ logger.debug("Performing ffmpeg read on audio data")
95
+ inputs = ffmpeg_read(audio_data, pipeline.feature_extractor.sampling_rate)
96
  inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
97
  logger.debug("ffmpeg read completed successfully")
98
  except Exception as e:
99
  logger.error(f"Error in ffmpeg read: {str(e)}", exc_info=True)
100
+ raise HTTPException(status_code=500, detail=f"Error processing audio data: {str(e)}")
101
 
102
  logger.debug("Calling tqdm_generate to transcribe audio")
103
  try:
 
114
  "timing_info": timing_info
115
  })
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
118
  start_time = time.time()
119
  logger.debug(f"Starting tqdm_generate - task: {task}, return_timestamps: {return_timestamps}")
 
175
  "total_processing_time": total_processing_time
176
  }
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
179
  if seconds is not None:
180
  milliseconds = round(seconds * 1000.0)