File size: 3,148 Bytes
2b16bc4 91251fa 96d549d 91251fa 2b16bc4 91251fa 96d549d 2b16bc4 96d549d 2b16bc4 96d549d 2b16bc4 91251fa 23d99fb 91251fa 96d549d 91251fa 96d549d 2b16bc4 91251fa 2b16bc4 96d549d 91251fa 03efb5d 91251fa 96d549d 2b16bc4 96d549d 2b16bc4 03efb5d 2b16bc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import argparse
import base64
import io
import logging
import os
from faster_whisper import WhisperModel
from pydub import AudioSegment
from file_processor import process_video
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def is_cdn_link(link_or_bytes):
    """Return True only if the input looks like a downloadable CDN/URL string.

    The previous implementation treated *every* non-bytes value as a CDN
    link, which mis-routed base64-encoded audio passed as ``str`` (the
    usual JSON-payload case — ``base64.b64decode`` accepts str) into the
    video-download path. Raw ``bytes`` and non-URL strings are both
    treated as inline audio payloads.

    Args:
        link_or_bytes: either a URL string or base64/binary audio content.

    Returns:
        bool: True for http(s) URL strings, False otherwise.
    """
    logging.info("Checking if the provided link is a CDN link...")
    if isinstance(link_or_bytes, bytes):
        return False
    # Only genuine URLs count as CDN links; any other string is assumed
    # to be base64-encoded audio.
    return isinstance(link_or_bytes, str) and link_or_bytes.startswith(
        ("http://", "https://")
    )
def get_audio_bytes(audio_path):
    """Load the audio file at *audio_path* and return it as an in-memory
    MP3 stream (a ``BytesIO`` rewound to position 0)."""
    out_stream = io.BytesIO()
    AudioSegment.from_file(audio_path).export(out_stream, format='mp3')
    out_stream.seek(0)
    return out_stream
class EndpointHandler:
    """Whisper transcription endpoint.

    Accepts either a CDN video link (slides are extracted alongside the
    audio) or base64-encoded audio, and returns per-segment transcripts.
    """

    def __init__(self, path=""):
        # `path` is unused but kept for compatibility with the standard
        # inference-endpoint handler signature.
        self.model = WhisperModel("large-v3", num_workers=30)

    def __call__(self, data: dict[str, str]) -> dict:
        """Transcribe the payload in *data*.

        Args:
            data: request payload with keys:
                "inputs": CDN link (str) or base64-encoded audio.
                "language": optional language code (default "de").
                "task": optional Whisper task (default "transcribe").

        Returns:
            dict with "audios" (list of segment dicts) and, for video
            inputs, "slides" (list of slide dicts).
        """
        inputs = data.pop("inputs")
        language = data.pop("language", "de")
        task = data.pop("task", "transcribe")
        response = {}
        audio_path = None
        try:
            if is_cdn_link(inputs):
                # Video input: download, split into slides + temp audio file.
                slides, audio_path = process_video(inputs)
                audio_bytes = get_audio_bytes(audio_path)
                response["slides"] = [slide.to_dict() for slide in slides]
            else:
                # Inline input: base64-encoded audio bytes.
                audio_bytes_decoded = base64.b64decode(inputs)
                logging.debug("Decoded Bytes Length: %d", len(audio_bytes_decoded))
                audio_bytes = io.BytesIO(audio_bytes_decoded)
            logging.info("Running inference...")
            segments, info = self.model.transcribe(audio_bytes, language=language, task=task)
            full_text = []
            for segment in segments:
                full_text.append({
                    "segmentId": segment.id,
                    "text": segment.text,
                    "timestamps": {
                        "start": segment.start,
                        "end": segment.end,
                    },
                })
                # Periodic progress logging on long transcriptions.
                if segment.id % 100 == 0:
                    logging.info("segment %s transcribed", segment.id)
            logging.info("Inference completed.")
            response["audios"] = full_text
            logging.debug(response)
            return response
        finally:
            # Always clean up the temporary audio file, even when
            # transcription raises — the original code leaked it on error.
            if audio_path:
                os.remove(audio_path)
if __name__ == '__main__':
    # Local smoke test for the handler. Fixes the previous version, which
    # ignored --path and read a hard-coded personal temp file.
    parser = argparse.ArgumentParser(description="EndpointHandler")
    parser.add_argument("-p", "--path", required=True,
                        help="CDN link or local audio file to transcribe")
    parser.add_argument("-l", "--language", default="de")
    parser.add_argument("-t", "--task", default="transcribe")
    parser.add_argument("--type", default="video")
    args = parser.parse_args()

    handler = EndpointHandler()
    if is_cdn_link(args.path):
        # URLs are passed straight through; the handler downloads them.
        test_inputs = args.path
    else:
        # Local files are re-encoded to MP3 and sent base64-encoded,
        # mirroring what a real client payload looks like.
        audio = AudioSegment.from_file(args.path)
        buffer = io.BytesIO()
        audio.export(buffer, format="mp3")
        test_inputs = base64.b64encode(buffer.getvalue())

    sample_data = {
        "inputs": test_inputs,
        "language": args.language,
        "task": args.task,
    }
    print(handler(sample_data))
|