Spaces:
Edmond98
/
Running on TPU v5e

Afrinetwork7 commited on
Commit
fbec879
1 Parent(s): 29850f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -10
app.py CHANGED
@@ -4,10 +4,13 @@ import os
4
  import tempfile
5
  import time
6
  from typing import Dict, Any
 
7
 
8
  import yt_dlp as youtube_dl
9
- from fastapi import FastAPI, UploadFile, Form, HTTPException
10
  from fastapi.responses import HTMLResponse
 
 
11
  import jax.numpy as jnp
12
  import numpy as np
13
  from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
@@ -50,7 +53,17 @@ random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, retur
50
  compile_time = time.time() - start
51
  logger.debug(f"Compiled in {compile_time}s")
52
 
 
 
 
 
 
 
 
 
 
53
  def timeit(func):
 
54
  async def wrapper(*args, **kwargs):
55
  start_time = time.time()
56
  result = await func(*args, **kwargs)
@@ -65,9 +78,12 @@ def timeit(func):
65
 
66
  @app.post("/transcribe_audio")
67
  @timeit
68
- async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcribe", return_timestamps: bool = False) -> Dict[str, Any]:
 
 
 
69
  logger.debug("Starting transcribe_chunked_audio function")
70
- logger.debug(f"Received parameters - task: {task}, return_timestamps: {return_timestamps}")
71
 
72
  logger.debug("Checking for audio file...")
73
  if not audio_file:
@@ -101,21 +117,25 @@ async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcri
101
 
102
  logger.debug("Calling tqdm_generate to transcribe audio")
103
  try:
104
- text, runtime, timing_info = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
105
  logger.debug(f"Transcription completed. Runtime: {runtime:.2f}s")
106
  except Exception as e:
107
  logger.error(f"Error in tqdm_generate: {str(e)}", exc_info=True)
108
  raise HTTPException(status_code=500, detail=f"Error transcribing audio: {str(e)}")
109
 
110
  logger.debug("Transcribe_chunked_audio function completed successfully")
111
- return {"text": text, "runtime": runtime, "timing_info": timing_info}
 
 
 
 
112
 
113
  @app.post("/transcribe_youtube")
114
  @timeit
115
- async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe", return_timestamps: bool = False) -> Dict[str, Any]:
116
  logger.debug("Loading YouTube file...")
117
  try:
118
- html_embed_str = _return_yt_html_embed(yt_url)
119
  except Exception as e:
120
  logger.error("Error generating YouTube HTML embed:", exc_info=True)
121
  raise HTTPException(status_code=500, detail="Error generating YouTube HTML embed")
@@ -124,7 +144,7 @@ async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe",
124
  filepath = os.path.join(tmpdirname, "video.mp4")
125
  try:
126
  logger.debug("Downloading YouTube audio...")
127
- download_yt_audio(yt_url, filepath)
128
  except Exception as e:
129
  logger.error("Error downloading YouTube audio:", exc_info=True)
130
  raise HTTPException(status_code=500, detail="Error downloading YouTube audio")
@@ -143,12 +163,17 @@ async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe",
143
 
144
  try:
145
  logger.debug("Calling tqdm_generate to transcribe YouTube audio")
146
- text, runtime, timing_info = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
147
  except Exception as e:
148
  logger.error("Error transcribing YouTube audio:", exc_info=True)
149
  raise HTTPException(status_code=500, detail="Error transcribing YouTube audio")
150
 
151
- return {"html_embed": html_embed_str, "text": text, "runtime": runtime, "timing_info": timing_info}
 
 
 
 
 
152
 
153
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
154
  start_time = time.time()
 
4
  import tempfile
5
  import time
6
  from typing import Dict, Any
7
+ from functools import wraps
8
 
9
  import yt_dlp as youtube_dl
10
+ from fastapi import FastAPI, File, UploadFile, Depends, HTTPException
11
  from fastapi.responses import HTMLResponse
12
+ from fastapi.encoders import jsonable_encoder
13
+ from pydantic import BaseModel
14
  import jax.numpy as jnp
15
  import numpy as np
16
  from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
 
53
  compile_time = time.time() - start
54
  logger.debug(f"Compiled in {compile_time}s")
55
 
56
+ class TranscribeAudioRequest(BaseModel):
57
+ task: str = "transcribe"
58
+ return_timestamps: bool = False
59
+
60
+ class TranscribeYouTubeRequest(BaseModel):
61
+ yt_url: str
62
+ task: str = "transcribe"
63
+ return_timestamps: bool = False
64
+
65
  def timeit(func):
66
+ @wraps(func)
67
  async def wrapper(*args, **kwargs):
68
  start_time = time.time()
69
  result = await func(*args, **kwargs)
 
78
 
79
  @app.post("/transcribe_audio")
80
  @timeit
81
+ async def transcribe_chunked_audio(
82
+ audio_file: UploadFile = File(...),
83
+ request: TranscribeAudioRequest = Depends()
84
+ ) -> Dict[str, Any]:
85
  logger.debug("Starting transcribe_chunked_audio function")
86
+ logger.debug(f"Received parameters - task: {request.task}, return_timestamps: {request.return_timestamps}")
87
 
88
  logger.debug("Checking for audio file...")
89
  if not audio_file:
 
117
 
118
  logger.debug("Calling tqdm_generate to transcribe audio")
119
  try:
120
+ text, runtime, timing_info = tqdm_generate(inputs, task=request.task, return_timestamps=request.return_timestamps)
121
  logger.debug(f"Transcription completed. Runtime: {runtime:.2f}s")
122
  except Exception as e:
123
  logger.error(f"Error in tqdm_generate: {str(e)}", exc_info=True)
124
  raise HTTPException(status_code=500, detail=f"Error transcribing audio: {str(e)}")
125
 
126
  logger.debug("Transcribe_chunked_audio function completed successfully")
127
+ return jsonable_encoder({
128
+ "text": text,
129
+ "runtime": runtime,
130
+ "timing_info": timing_info
131
+ })
132
 
133
  @app.post("/transcribe_youtube")
134
  @timeit
135
+ async def transcribe_youtube(request: TranscribeYouTubeRequest) -> Dict[str, Any]:
136
  logger.debug("Loading YouTube file...")
137
  try:
138
+ html_embed_str = _return_yt_html_embed(request.yt_url)
139
  except Exception as e:
140
  logger.error("Error generating YouTube HTML embed:", exc_info=True)
141
  raise HTTPException(status_code=500, detail="Error generating YouTube HTML embed")
 
144
  filepath = os.path.join(tmpdirname, "video.mp4")
145
  try:
146
  logger.debug("Downloading YouTube audio...")
147
+ download_yt_audio(request.yt_url, filepath)
148
  except Exception as e:
149
  logger.error("Error downloading YouTube audio:", exc_info=True)
150
  raise HTTPException(status_code=500, detail="Error downloading YouTube audio")
 
163
 
164
  try:
165
  logger.debug("Calling tqdm_generate to transcribe YouTube audio")
166
+ text, runtime, timing_info = tqdm_generate(inputs, task=request.task, return_timestamps=request.return_timestamps)
167
  except Exception as e:
168
  logger.error("Error transcribing YouTube audio:", exc_info=True)
169
  raise HTTPException(status_code=500, detail="Error transcribing YouTube audio")
170
 
171
+ return jsonable_encoder({
172
+ "html_embed": html_embed_str,
173
+ "text": text,
174
+ "runtime": runtime,
175
+ "timing_info": timing_info
176
+ })
177
 
178
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
179
  start_time = time.time()