ManBib commited on
Commit
ab0e749
·
1 Parent(s): 7406982

reset to only audio processing

Browse files
Files changed (1) hide show
  1. handler.py +13 -73
handler.py CHANGED
@@ -1,30 +1,9 @@
1
- import argparse
2
- import base64
3
  import io
4
- import logging
5
- import os
6
-
7
  from faster_whisper import WhisperModel
8
- from pydub import AudioSegment
9
-
10
- from file_processor import process_video
11
-
12
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
-
14
-
15
- def is_cdn_link(link_or_bytes):
16
- logging.info("Checking if the provided link is a CDN link...")
17
- if isinstance(link_or_bytes, bytes):
18
- return False
19
- return True
20
-
21
 
22
- def get_audio_bytes(audio_path):
23
- audio = AudioSegment.from_file(audio_path)
24
- buffer = io.BytesIO()
25
- audio.export(buffer, format='mp3')
26
- buffer.seek(0)
27
- return buffer
28
 
29
 
30
  class EndpointHandler:
@@ -32,26 +11,21 @@ class EndpointHandler:
32
  self.model = WhisperModel("large-v3", num_workers=30)
33
 
34
  def __call__(self, data: dict[str, str]):
35
- inputs = data.pop("inputs")
36
-
37
  language = data.pop("language", "de")
38
  task = data.pop("task", "transcribe")
39
- response = {}
40
- audio_path = None
41
 
42
- if is_cdn_link(inputs):
43
- slides, audio_path = process_video(inputs)
44
- audio_bytes = get_audio_bytes(audio_path)
45
- slides_list = [slide.to_dict() for slide in slides]
46
- response.update({"slides": slides_list})
47
- else:
48
- audio_bytes_decoded = base64.b64decode(inputs)
49
- logging.debug(f"Decoded Bytes Length: {len(audio_bytes_decoded)}")
50
- audio_bytes = io.BytesIO(audio_bytes_decoded)
51
 
 
52
  logging.info("Running inference...")
53
- segments, info = self.model.transcribe(audio_bytes, language=language, task=task, )
54
 
 
55
  full_text = []
56
  for segment in segments:
57
  full_text.append({"segmentId": segment.id,
@@ -66,38 +40,4 @@ class EndpointHandler:
66
  logging.info("segment " + str(segment.id) + " transcribed")
67
  logging.info("Inference completed.")
68
 
69
- response.update({"audios": full_text})
70
- logging.debug(response)
71
- if audio_path:
72
- os.remove(audio_path)
73
- return response
74
-
75
-
76
- if __name__ == '__main__':
77
- Parser = argparse.ArgumentParser(description="EndpointHandler")
78
- Parser.add_argument("-p", "--path")
79
- Parser.add_argument("-l", "--language", default="de")
80
- Parser.add_argument("-t", "--task", default="transcribe")
81
- Parser.add_argument("--type", default="video")
82
- Args = Parser.parse_args()
83
-
84
- handler = EndpointHandler()
85
-
86
-
87
- # if is_cdn_link(Args.path):
88
- # test_inputs = Args.path
89
- # else:
90
- audio = AudioSegment.from_mp3(r"C:\Users\mbabu\AppData\Local\Temp\tmpsezkw2i5.mp3")
91
- buffer = io.BytesIO()
92
- audio.export(buffer, format="mp3")
93
- mp3_bytes = buffer.getvalue()
94
- test_inputs = base64.b64encode(mp3_bytes)
95
-
96
- sample_data = {
97
- "inputs": test_inputs,
98
- "language": Args.language,
99
- "task": Args.task,
100
- }
101
-
102
- test = handler(sample_data)
103
- print(test)
 
 
 
1
  import io
2
+ import base64
 
 
3
  from faster_whisper import WhisperModel
4
+ import logging
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ logging.basicConfig(level=logging.DEBUG)
 
 
 
 
 
7
 
8
 
9
  class EndpointHandler:
 
11
  self.model = WhisperModel("large-v3", num_workers=30)
12
 
13
  def __call__(self, data: dict[str, str]):
14
+ # process inputs
15
+ inputs = data.pop("inputs", data)
16
  language = data.pop("language", "de")
17
  task = data.pop("task", "transcribe")
 
 
18
 
19
+ # Decode base64 string to bytes
20
+ audio_bytes_decoded = base64.b64decode(inputs)
21
+ logging.debug(f"Decoded Bytes Length: {len(audio_bytes_decoded)}")
22
+ audio_bytes = io.BytesIO(audio_bytes_decoded)
 
 
 
 
 
23
 
24
+ # run inference pipeline
25
  logging.info("Running inference...")
26
+ segments, info = self.model.transcribe(audio_bytes, language=language, task=task)
27
 
28
+ # postprocess the prediction
29
  full_text = []
30
  for segment in segments:
31
  full_text.append({"segmentId": segment.id,
 
40
  logging.info("segment " + str(segment.id) + " transcribed")
41
  logging.info("Inference completed.")
42
 
43
+ return full_text