# faster-whisper-readme / handler.py
# ManBib's picture
# reset process from cpu to gpu
# 23d99fb
# raw
# history blame
# 3.15 kB
import argparse
import base64
import io
import logging
import os
from faster_whisper import WhisperModel
from pydub import AudioSegment
from file_processor import process_video
# Configure the root logger at import time: INFO and above, with a
# timestamp and level name in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def is_cdn_link(link_or_bytes):
    """Tell whether the input should be treated as a link instead of raw bytes.

    The decision is purely type-based: a ``bytes`` instance yields ``False``
    and every other value yields ``True``. The content of the value is never
    inspected, so no URL validation takes place here.
    """
    logging.info("Checking if the provided link is a CDN link...")
    return not isinstance(link_or_bytes, bytes)
def get_audio_bytes(audio_path):
    """Read the audio file at ``audio_path`` and return it re-encoded as MP3.

    The result is an in-memory ``io.BytesIO`` positioned at offset 0, so the
    caller can read the encoded audio from the beginning.
    """
    mp3_stream = io.BytesIO()
    AudioSegment.from_file(audio_path).export(mp3_stream, format='mp3')
    # Rewind so consumers start reading at the first byte.
    mp3_stream.seek(0)
    return mp3_stream
class EndpointHandler:
    """Callable endpoint that transcribes audio with a faster-whisper model.

    The handler accepts either a link (any non-``bytes`` input, which is
    handed to ``process_video`` to obtain slides plus an audio file) or
    base64-encoded audio bytes, and returns the transcribed segments.
    """

    def __init__(self, path=""):
        """Load the ``large-v3`` Whisper model.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        self.model = WhisperModel("large-v3", num_workers=30)

    def __call__(self, data: dict[str, str]):
        """Transcribe the audio described by ``data``.

        Args:
            data: Request payload. ``"inputs"`` is required; ``"language"``
                (default ``"de"``) and ``"task"`` (default ``"transcribe"``)
                are optional. All three keys are popped from ``data``.

        Returns:
            A dict with an ``"audios"`` list of
            ``{"segmentId", "text", "timestamps": {"start", "end"}}`` entries
            and, for link inputs, a ``"slides"`` list built from each slide's
            ``to_dict()``.

        Raises:
            KeyError: If ``"inputs"`` is missing from ``data``.
        """
        inputs = data.pop("inputs")
        language = data.pop("language", "de")
        task = data.pop("task", "transcribe")

        response = {}
        audio_path = None
        try:
            if is_cdn_link(inputs):
                slides, audio_path = process_video(inputs)
                audio_bytes = get_audio_bytes(audio_path)
                response["slides"] = [slide.to_dict() for slide in slides]
            else:
                audio_bytes_decoded = base64.b64decode(inputs)
                logging.debug(f"Decoded Bytes Length: {len(audio_bytes_decoded)}")
                audio_bytes = io.BytesIO(audio_bytes_decoded)

            logging.info("Running inference...")
            segments, _ = self.model.transcribe(audio_bytes, language=language, task=task)

            full_text = []
            for segment in segments:
                full_text.append({
                    "segmentId": segment.id,
                    "text": segment.text,
                    "timestamps": {
                        "start": segment.start,
                        "end": segment.end,
                    },
                })
                # Periodic progress marker for long recordings.
                if segment.id % 100 == 0:
                    logging.info("segment %s transcribed", segment.id)
            logging.info("Inference completed.")

            response["audios"] = full_text
            logging.debug(response)
        finally:
            # Remove the audio file produced by process_video even when
            # re-encoding or transcription fails, so failed requests do not
            # leave files behind.
            if audio_path:
                os.remove(audio_path)
        return response
if __name__ == '__main__':
    # Manual smoke test: build a request payload from the command line, run
    # it through the handler once and print the response.
    parser = argparse.ArgumentParser(description="EndpointHandler")
    parser.add_argument("-p", "--path")
    parser.add_argument("-l", "--language", default="de")
    parser.add_argument("-t", "--task", default="transcribe")
    # NOTE: --type is parsed for CLI compatibility but is not consulted below.
    parser.add_argument("--type", default="video")
    args = parser.parse_args()

    handler = EndpointHandler()

    if args.path and args.path.startswith(("http://", "https://")):
        # A str input is routed through process_video by the handler.
        test_inputs = args.path
    else:
        # Use the file given via --path; without it, fall back to the
        # developer-machine sample file this script originally hard-coded.
        mp3_path = args.path or r"C:\Users\mbabu\AppData\Local\Temp\tmpsezkw2i5.mp3"
        audio = AudioSegment.from_mp3(mp3_path)
        buffer = io.BytesIO()
        audio.export(buffer, format="mp3")
        # bytes input makes the handler take the base64-decoding branch.
        test_inputs = base64.b64encode(buffer.getvalue())

    sample_data = {
        "inputs": test_inputs,
        "language": args.language,
        "task": args.task,
    }
    test = handler(sample_data)
    print(test)