minhpng committed on
Commit
a9e849e
·
1 Parent(s): 156bbee
libs/transformer/get_transcript.py CHANGED
@@ -2,30 +2,28 @@ import torch
2
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
3
 
4
 
5
-
6
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
7
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
8
-
9
- model_id = "distil-whisper/distil-large-v3"
10
-
11
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
12
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
13
- )
14
- model.to(device)
15
-
16
- processor = AutoProcessor.from_pretrained(model_id)
17
-
18
- pipe = pipeline(
19
- "automatic-speech-recognition",
20
- model=model,
21
- tokenizer=processor.tokenizer,
22
- feature_extractor=processor.feature_extractor,
23
- max_new_tokens=128,
24
- torch_dtype=torch_dtype,
25
- device=device,
26
- return_timestamps=True
27
- )
28
-
29
-
30
- result = pipe("https://static.langkingdom.com/user_playlist_practice_videos/2114103294b5c15605fd59773e948e58.mp3")
31
- print(result)
 
2
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
3
 
4
 
5
def get_transcript_gpu(url: str, model_id: str):
    """Transcribe the audio referenced by ``url`` using the model ``model_id``.

    The model and its processor are loaded from ``model_id`` and wrapped in an
    ``automatic-speech-recognition`` pipeline with timestamps enabled. When
    CUDA is available the model runs on ``cuda:0`` in float16; otherwise it
    runs on the CPU in float32. ``url`` is handed to the pipeline unchanged.

    Args:
        url: Audio location passed straight to the pipeline.
        model_id: Identifier given to both ``from_pretrained`` calls.

    Returns:
        A ``(text, chunks)`` tuple taken from the pipeline result via
        ``.get("text")`` and ``.get("chunks")``.
    """
    # Pick the device and the matching precision together.
    if torch.cuda.is_available():
        target_device, precision = "cuda:0", torch.float16
    else:
        target_device, precision = "cpu", torch.float32

    # NOTE(review): the weights are loaded again on every call; if this is
    # invoked per request, consider reusing a loaded pipeline -- confirm the
    # memory budget first.
    speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=precision,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    speech_model.to(target_device)

    preprocessor = AutoProcessor.from_pretrained(model_id)

    pipeline_options = {
        "model": speech_model,
        "tokenizer": preprocessor.tokenizer,
        "feature_extractor": preprocessor.feature_extractor,
        "max_new_tokens": 128,
        "torch_dtype": precision,
        "device": target_device,
        "return_timestamps": True,
    }
    transcriber = pipeline("automatic-speech-recognition", **pipeline_options)

    output = transcriber(url)
    return output.get("text"), output.get("chunks")
 
 
routers/get_transcript_transformer.py CHANGED
@@ -1,6 +1,9 @@
 
1
  import time
2
  from fastapi import APIRouter, Depends, HTTPException, status
3
 
 
 
4
  from libs.transformer.get_transcript_2 import get_transcribe_transformers
5
 
6
  from libs.header_api_auth import get_api_key
@@ -9,16 +12,30 @@ from libs.header_api_auth import get_api_key
9
  router = APIRouter(prefix="/get-transcript-transformer", tags=["transcript"])
10
 
11
  @router.get("/")
12
- def get_transcript(audio_path: str, api_key: str = Depends(get_api_key)):
13
  st = time.time()
14
 
 
 
 
 
 
 
 
 
 
15
  try:
16
- text, chunks = get_transcribe_transformers(audio_path)
17
  except Exception as error:
18
  raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"error>>>: {error}")
 
 
 
19
 
20
  listSentences = []
21
 
 
 
22
  for chunk in chunks:
23
  listSentences.append({
24
  "start_time": chunk.get("timestamp")[0],
 
1
+ import os
2
  import time
3
  from fastapi import APIRouter, Depends, HTTPException, status
4
 
5
+ from libs.convert_to_audio import convert_to_audio
6
+ from libs.transformer.get_transcript import get_transcript_gpu
7
  from libs.transformer.get_transcript_2 import get_transcribe_transformers
8
 
9
  from libs.header_api_auth import get_api_key
 
12
  router = APIRouter(prefix="/get-transcript-transformer", tags=["transcript"])
13
 
14
  @router.get("/")
15
+ def get_transcript(audio_path: str, model_size: str = "distil-whisper/distil-small.en", api_key: str = Depends(get_api_key)):
16
  st = time.time()
17
 
18
+ output_audio_folder = f"./cached/audio"
19
+
20
+ if not os.path.exists(output_audio_folder):
21
+ os.makedirs(output_audio_folder)
22
+
23
+ output_file = f"{output_audio_folder}/{audio_path.split('/')[-1].split(".")[0]}.mp3"
24
+
25
+ convert_to_audio(audio_path.strip(), output_file)
26
+
27
  try:
28
+ text, chunks = get_transcript_gpu(output_file, model_size)
29
  except Exception as error:
30
  raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"error>>>: {error}")
31
+ finally:
32
+ if os.path.exists(output_file):
33
+ os.remove(output_file)
34
 
35
  listSentences = []
36
 
37
+ print(chunks)
38
+
39
  for chunk in chunks:
40
  listSentences.append({
41
  "start_time": chunk.get("timestamp")[0],