import base64
import json
import os
from io import StringIO
from typing import Dict, Any

from transformers import pipeline


class EndpointHandler:
    def __init__(self, asr_model_path: str = "./whisper-large-v2"):
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model_path,
        )

    def __call__(self, data: Dict[str, Any]) -> str:
        # Accept either an already-parsed dict or a raw JSON string.
        json_data = json.loads(data) if isinstance(data, str) else data
        if "audio_data" not in json_data:
            raise ValueError("Request must contain a top-level key named 'audio_data'")

        audio_data = json_data["audio_data"]
        # Optional; when omitted, Whisper falls back to automatic language detection.
        language = json_data.get("language")

        # Base64-encoded audio arrives as a string; decode it to raw bytes.
        if isinstance(audio_data, str):
            audio_data = base64.b64decode(audio_data)

        transcription = self.asr_pipeline(
            audio_data,
            return_timestamps=False,
            chunk_length_s=30,
            batch_size=8,
            max_new_tokens=444,  # stay within Whisper's 448-token decoder limit
            generate_kwargs={"task": "transcribe", "language": language},
        )

        result = StringIO()
        json.dump(transcription, result)
        return result.getvalue()

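
# A minimal local sketch of how the handler above could be exercised. The
# "sample.wav" file name and the "de" language code are illustrative
# assumptions, not part of any deployed endpoint contract.
#
#     handler = EndpointHandler(asr_model_path="./whisper-large-v2")
#     with open("sample.wav", "rb") as f:
#         payload = {
#             "audio_data": base64.b64encode(f.read()).decode("utf-8"),
#             "language": "de",
#         }
#     print(handler(payload))  # JSON string such as {"text": "..."}
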

def init():
    global asr_pipeline

    # AZUREML_MODEL_DIR points at the model files mounted into the Azure ML deployment.
    model_path = os.getenv("AZUREML_MODEL_DIR", "./whisper-large-v2")

    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model_path,
    )


def run(raw_data):
    json_data = json.loads(raw_data)
    if "audio_data" not in json_data:
        raise ValueError("Request must contain a top-level key named 'audio_data'")

    audio_data = json_data["audio_data"]

    # Base64-encoded audio arrives as a string; decode it to raw bytes.
    if isinstance(audio_data, str):
        audio_data = base64.b64decode(audio_data)

    transcription = asr_pipeline(
        audio_data,
        return_timestamps=False,
        chunk_length_s=30,
        batch_size=8,
        max_new_tokens=444,  # stay within Whisper's 448-token decoder limit
        generate_kwargs={"task": "transcribe", "language": "<|de|>"},
    )

    result = StringIO()
    json.dump(transcription, result)
    return result.getvalue()
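

if __name__ == "__main__":
    # Minimal local smoke test for the Azure ML-style entry points above. The
    # "sample.wav" path is an illustrative assumption; point it at a real recording.
    init()
    with open("sample.wav", "rb") as f:
        sample_payload = json.dumps(
            {"audio_data": base64.b64encode(f.read()).decode("utf-8")}
        )
    print(run(sample_payload))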