File size: 2,712 Bytes
a8eab90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import base64
import json
import os
from io import StringIO
from typing import Dict, Any
from transformers import pipeline
class EndpointHandler:
    """Transcribe audio requests with a transformers ASR pipeline.

    The pipeline is loaded once in ``__init__``; each call handles one
    request and returns the pipeline's result serialized as JSON.
    """

    def __init__(self, asr_model_path: str = "./whisper-large-v2"):
        """Create the ASR pipeline from the model stored at ``asr_model_path``."""
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=asr_model_path,
        )

    @staticmethod
    def _language_token(language: str) -> str:
        """Return ``language`` in ``<|xx|>`` token form, wrapping bare codes."""
        language = language.strip()
        if language.startswith("<|") and language.endswith("|>"):
            return language
        return "<|" + language + "|>"

    def __call__(self, data: Dict[str, Any]) -> str:
        """Transcribe the audio contained in one request.

        Args:
            data: The request, either as an already-parsed mapping or as a
                JSON document (``str``/``bytes``). It must contain
                ``"audio_data"`` (raw bytes or a base64 string) and may
                contain ``"language"`` (a code such as ``"de"`` or a token
                such as ``"<|de|>"``).

        Returns:
            The pipeline output serialized as a JSON string.

        Raises:
            ValueError: If the request has no top-level ``"audio_data"`` key.
        """
        # The signature promises a dict while the previous code always ran
        # json.loads; accept both so either kind of caller works.
        if isinstance(data, (str, bytes, bytearray)):
            json_data = json.loads(data)
        else:
            json_data = data

        if "audio_data" not in json_data:
            raise ValueError(
                "Request must contain a top-level key named 'audio_data'"
            )

        audio_data = json_data["audio_data"]
        # Decode the binary audio data if it's provided as a base64 string
        if isinstance(audio_data, str):
            audio_data = base64.b64decode(audio_data)

        # Forward the language the caller asked for. The previous code sent
        # the literal placeholder "<|language|>" and ignored the request.
        # Without a language, the key is omitted and the choice is left to
        # the pipeline/model defaults.
        generate_kwargs = {"task": "transcribe"}
        language = json_data.get("language")
        if language:
            generate_kwargs["language"] = self._language_token(language)

        # `max_length` is intentionally not passed: `max_new_tokens` already
        # bounds the generated length and the two settings overlap.
        transcription = self.asr_pipeline(
            audio_data,
            return_timestamps=False,
            chunk_length_s=30,
            batch_size=8,
            max_new_tokens=10000,
            generate_kwargs=generate_kwargs,
        )

        return json.dumps(transcription)
def init():
    """Load the ASR pipeline into the module-level ``asr_pipeline`` global.

    The model directory is taken from the ``AZUREML_MODEL_DIR`` environment
    variable, falling back to ``./whisper-large-v2`` when it is not set.
    """
    global asr_pipeline

    # Resolve where the model files live before building the pipeline.
    model_dir = os.environ.get("AZUREML_MODEL_DIR", "./whisper-large-v2")

    asr_pipeline = pipeline("automatic-speech-recognition", model=model_dir)
def run(raw_data):
    """Transcribe the audio in one JSON request using the global pipeline.

    Args:
        raw_data: JSON document with a top-level ``"audio_data"`` entry
            (a base64 string, decoded before use) and an optional
            ``"language"`` entry (a code such as ``"en"`` or a token such as
            ``"<|en|>"``). When no language is given, German (``"<|de|>"``)
            is used, as before.

    Returns:
        The pipeline output serialized as a JSON string.

    Raises:
        ValueError: If the request has no top-level ``"audio_data"`` key.
    """
    json_data = json.loads(raw_data)
    if "audio_data" not in json_data:
        raise ValueError(
            "Request must contain a top level key named 'audio_data'"
        )

    audio_data = json_data["audio_data"]
    # Decode the binary audio data if it's provided as a base64 string
    if isinstance(audio_data, str):
        audio_data = base64.b64decode(audio_data)

    # The language used to be hard-coded to German; keep that as the default
    # but let the request override it. Bare codes are wrapped into the
    # "<|xx|>" token form that the hard-coded value used.
    language = (json_data.get("language") or "<|de|>").strip()
    if not (language.startswith("<|") and language.endswith("|>")):
        language = "<|" + language + "|>"

    # `asr_pipeline` is the module-level global created by init().
    transcription = asr_pipeline(
        audio_data,
        return_timestamps=False,
        chunk_length_s=30,
        batch_size=8,
        max_new_tokens=1000,
        generate_kwargs={"task": "transcribe", "language": language},
    )

    return json.dumps(transcription)
|