File size: 3,148 Bytes
2b16bc4
91251fa
96d549d
91251fa
2b16bc4
91251fa
96d549d
2b16bc4
96d549d
2b16bc4
96d549d
2b16bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91251fa
 
 
 
23d99fb
91251fa
 
96d549d
 
91251fa
 
96d549d
2b16bc4
91251fa
2b16bc4
 
 
96d549d
 
 
 
 
 
91251fa
 
03efb5d
91251fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96d549d
 
2b16bc4
 
96d549d
2b16bc4
 
 
 
 
 
 
 
 
 
 
 
 
03efb5d
 
 
 
 
 
 
 
2b16bc4
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import base64
import io
import logging
import os

from faster_whisper import WhisperModel
from pydub import AudioSegment

from file_processor import process_video

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def is_cdn_link(link_or_bytes):
    """Return True when the payload is a link (string), False for raw audio bytes."""
    logging.info("Checking if the provided link is a CDN link...")
    # Anything that is not a bytes payload is treated as a CDN link.
    return not isinstance(link_or_bytes, bytes)


def get_audio_bytes(audio_path):
    """Load the audio file at *audio_path* and return it MP3-encoded in a BytesIO.

    The returned buffer is rewound to position 0, ready for reading.
    """
    out = io.BytesIO()
    AudioSegment.from_file(audio_path).export(out, format='mp3')
    out.seek(0)
    return out


class EndpointHandler:
    """Inference endpoint handler that transcribes audio with faster-whisper.

    Accepts either a CDN link to a video (slides are extracted alongside the
    audio track) or base64-encoded raw audio bytes.
    """

    def __init__(self, path=""):
        # `path` is part of the endpoint-handler contract but unused here.
        self.model = WhisperModel("large-v3", num_workers=30)

    def __call__(self, data: dict[str, str]):
        """Run transcription on the request payload.

        Args:
            data: Mutated in place. Must contain "inputs" (CDN link string or
                base64-encoded audio bytes); optional "language" (default "de")
                and "task" (default "transcribe").

        Returns:
            dict with "audios" (list of segment dicts with id/text/timestamps)
            and, for video links, "slides" (list of slide dicts).

        Raises:
            KeyError: if "inputs" is missing from `data`.
        """
        inputs = data.pop("inputs")

        language = data.pop("language", "de")
        task = data.pop("task", "transcribe")
        response = {}
        audio_path = None

        try:
            if is_cdn_link(inputs):
                # Video link: extract slides plus a temporary audio file.
                slides, audio_path = process_video(inputs)
                audio_bytes = get_audio_bytes(audio_path)
                response["slides"] = [slide.to_dict() for slide in slides]
            else:
                # Raw payload: base64-encoded audio bytes.
                audio_bytes_decoded = base64.b64decode(inputs)
                logging.debug(f"Decoded Bytes Length: {len(audio_bytes_decoded)}")
                audio_bytes = io.BytesIO(audio_bytes_decoded)

            logging.info("Running inference...")
            segments, info = self.model.transcribe(audio_bytes, language=language, task=task)

            full_text = []
            for segment in segments:
                full_text.append({"segmentId": segment.id,
                                  "text": segment.text,
                                  "timestamps": {
                                      "start": segment.start,
                                      "end": segment.end
                                  }
                                  })

                # Progress heartbeat for long transcriptions.
                if segment.id % 100 == 0:
                    logging.info("segment " + str(segment.id) + " transcribed")
            logging.info("Inference completed.")

            response["audios"] = full_text
            logging.debug(response)
            return response
        finally:
            # Always remove the temp audio file extracted from the video —
            # the original only cleaned up on success, leaking the file when
            # transcription raised.
            if audio_path and os.path.exists(audio_path):
                os.remove(audio_path)


if __name__ == '__main__':
    # Local CLI harness: base64-encode an audio file and feed it through the
    # handler, mimicking a request payload.
    parser = argparse.ArgumentParser(description="EndpointHandler")
    parser.add_argument(
        "-p", "--path",
        # Previous hard-coded debug path kept as the default so existing
        # invocations without --path behave the same.
        default=r"C:\Users\mbabu\AppData\Local\Temp\tmpsezkw2i5.mp3",
        help="Audio file to transcribe (MP3-encoded and base64'd before sending).",
    )
    parser.add_argument("-l", "--language", default="de")
    parser.add_argument("-t", "--task", default="transcribe")
    parser.add_argument("--type", default="video")
    args = parser.parse_args()

    handler = EndpointHandler()

    # Re-encode as MP3 so the payload matches what the endpoint expects.
    audio = AudioSegment.from_mp3(args.path)
    buffer = io.BytesIO()
    audio.export(buffer, format="mp3")
    test_inputs = base64.b64encode(buffer.getvalue())

    sample_data = {
        "inputs": test_inputs,
        "language": args.language,
        "task": args.task,
    }

    print(handler(sample_data))