|
import argparse |
|
import base64 |
|
import io |
|
import logging |
|
import os |
|
|
|
from faster_whisper import WhisperModel |
|
from pydub import AudioSegment |
|
|
|
from file_processor import process_video |
|
|
|
# Configure the root logger once at import time so every message emitted by
# this module carries a timestamp and its severity level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
def is_cdn_link(link_or_bytes):
    """Report whether a request input should be handled as a link.

    The decision is purely type-based: binary payloads are considered raw
    (non-link) data, everything else is considered a link.  No validation of
    the link itself is performed.

    Args:
        link_or_bytes: The raw input value, either a link or a binary
            payload.

    Returns:
        ``False`` when the value is ``bytes``, ``bytearray`` or
        ``memoryview``; ``True`` for any other value.
    """
    logging.info("Checking if the provided link is a CDN link...")
    # Every bytes-like object is a payload, not a link.  Checking only for
    # ``bytes`` would misclassify ``bytearray``/``memoryview`` inputs, which
    # base64 decoding accepts just as well as ``bytes``.
    return not isinstance(link_or_bytes, (bytes, bytearray, memoryview))
|
|
|
|
|
def get_audio_bytes(audio_path):
    """Load an audio file and return it re-encoded as MP3 in memory.

    Args:
        audio_path: Location of the audio file to read.

    Returns:
        An ``io.BytesIO`` holding the MP3 data, positioned at offset 0 so it
        can be read from the beginning.
    """
    mp3_stream = io.BytesIO()
    AudioSegment.from_file(audio_path).export(mp3_stream, format='mp3')
    # Rewind so that consumers start reading at the first byte.
    mp3_stream.seek(0)
    return mp3_stream
|
|
|
|
|
class EndpointHandler:
    """Callable endpoint that transcribes audio with a faster-whisper model.

    The handler accepts either a link (forwarded to ``process_video``, which
    yields slides plus the path of an extracted audio file) or a
    base64-encoded audio payload.
    """

    def __init__(self, path=""):
        """Load the ``large-v3`` Whisper model.

        Args:
            path: Accepted for interface compatibility; not used by this
                implementation.
        """
        self.model = WhisperModel("large-v3", num_workers=30)

    def __call__(self, data: dict[str, str]):
        """Transcribe the audio described by ``data``.

        Args:
            data: Request payload.  ``"inputs"`` (required) is a link or a
                base64-encoded audio payload; ``"language"`` (default
                ``"de"``) and ``"task"`` (default ``"transcribe"``) are
                forwarded to the model.  These keys are popped from the
                dictionary, i.e. ``data`` is mutated.

        Returns:
            A dict with an ``"audios"`` list holding one entry per segment
            (``segmentId``, ``text`` and ``timestamps`` with ``start`` /
            ``end``) and, for link inputs, a ``"slides"`` list built from
            ``slide.to_dict()``.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` key.
        """
        inputs = data.pop("inputs")
        language = data.pop("language", "de")
        task = data.pop("task", "transcribe")
        response = {}
        audio_path = None

        try:
            if is_cdn_link(inputs):
                slides, audio_path = process_video(inputs)
                audio_bytes = get_audio_bytes(audio_path)
                response["slides"] = [slide.to_dict() for slide in slides]
            else:
                audio_bytes_decoded = base64.b64decode(inputs)
                logging.debug("Decoded Bytes Length: %d", len(audio_bytes_decoded))
                audio_bytes = io.BytesIO(audio_bytes_decoded)

            logging.info("Running inference...")
            segments, _ = self.model.transcribe(audio_bytes, language=language, task=task)

            # The segments are consumed inside the ``try`` block so that a
            # failure while iterating them still triggers the cleanup below.
            full_text = []
            for segment in segments:
                full_text.append({
                    "segmentId": segment.id,
                    "text": segment.text,
                    "timestamps": {
                        "start": segment.start,
                        "end": segment.end,
                    },
                })
                # Periodic progress marker for long recordings.
                if segment.id % 100 == 0:
                    logging.info("segment %s transcribed", segment.id)
            logging.info("Inference completed.")
        finally:
            # Always delete the extracted audio file, even when conversion or
            # transcription failed; otherwise every failed request leaks one.
            if audio_path:
                try:
                    os.remove(audio_path)
                except OSError:
                    # A failed cleanup must neither discard a finished
                    # transcription nor mask the original exception.
                    logging.warning(
                        "Could not remove temporary audio file %s",
                        audio_path,
                        exc_info=True,
                    )

        response["audios"] = full_text
        logging.debug(response)
        return response
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: encode one audio file as base64 and run it through
    # the handler exactly like a binary request payload.
    parser = argparse.ArgumentParser(description="EndpointHandler")
    parser.add_argument("-p", "--path")
    parser.add_argument("-l", "--language", default="de")
    parser.add_argument("-t", "--task", default="transcribe")
    # NOTE(review): --type is parsed but not consumed anywhere in this script.
    parser.add_argument("--type", default="video")
    args = parser.parse_args()

    handler = EndpointHandler()

    # --path used to be ignored in favour of this machine-specific file.  It
    # is kept only as the fallback so that invocations without --path behave
    # as before.
    legacy_audio_file = r"C:\Users\mbabu\AppData\Local\Temp\tmpsezkw2i5.mp3"
    audio_file = args.path or legacy_audio_file

    # Reuse the module's converter instead of duplicating the export logic.
    mp3_bytes = get_audio_bytes(audio_file).getvalue()
    test_inputs = base64.b64encode(mp3_bytes)

    sample_data = {
        "inputs": test_inputs,
        "language": args.language,
        "task": args.task,
    }

    test = handler(sample_data)
    print(test)
|
|