File size: 3,235 Bytes
3d36ead 9d67e70 3d36ead 880d01b 9d67e70 691565a 3d36ead a49a698 3d36ead 9d67e70 e06292f 79b6e07 3d36ead ef49102 e06292f 3d36ead a443875 ef49102 3d36ead 9d67e70 691565a d62e007 9f147d5 691565a 9d67e70 3d36ead 9d67e70 3d36ead 9d67e70 3d36ead |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from typing import Dict, List, Any, Literal, Optional, Tuple
import torch
import logging
from pydantic_settings import BaseSettings
from pydantic import field_validator
class EndpointHandler():
    """Hugging Face Inference Endpoint handler wrapping the Whisper
    large-v3-turbo automatic-speech-recognition pipeline.

    Generation parameters are pulled from WHISPER_KWARGS_* environment
    variables via ``WhisperParameterHandler`` on every call.
    """

    def __init__(self, path=""):
        # Prefer GPU + fp16 when CUDA is available; fall back to CPU + fp32.
        use_cuda = torch.cuda.is_available()
        device = "cuda:0" if use_cuda else "cpu"
        torch_dtype = torch.float16 if use_cuda else torch.float32
        model_id = "openai/whisper-large-v3-turbo"
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype,
            low_cpu_mem_usage=True, use_safetensors=True,
            attn_implementation="sdpa"
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj: `str`)
            parameters (:obj: `Any`)
        Return:
            A :obj:`dict` with at least a `text` key (plus `chunks` when
            timestamps were requested); will be serialized and returned.
        """
        # Get inputs; fall back to the whole payload if no "inputs" key.
        inputs = data.pop("inputs", data)
        whisper_parameter_handler = WhisperParameterHandler()
        # `return_timestamps` is a pipeline-level argument, not a generate
        # kwarg, so it is excluded from the generate_kwargs dict.
        # NOTE: pydantic v2's `exclude` expects a set, not a list.
        generate_kwargs = whisper_parameter_handler.model_dump(
            exclude_none=True, exclude={"return_timestamps"}
        )
        logging.info("Whisper generate_kwargs: %s", generate_kwargs)
        # Run normal prediction.
        prediction = self.pipe(
            inputs,
            return_timestamps=whisper_parameter_handler.return_timestamps,
            generate_kwargs=generate_kwargs,
        )
        logging.info("Prediction: %s", prediction)
        # BUG FIX: `chunks` is only present when timestamps were requested;
        # direct indexing raised KeyError otherwise.
        logging.info("Chunks: %s", prediction.get("chunks"))
        return prediction
class WhisperParameterHandler(BaseSettings):
    """Whisper generation settings sourced from the environment.

    Each field maps to a WHISPER_KWARGS_<FIELD> environment variable
    (case-insensitive). Unset fields stay ``None`` and are dropped when
    the settings are dumped with ``exclude_none=True``.
    """

    language: Optional[str] = None  # Optional fields default to None
    max_new_tokens: Optional[int] = None
    num_beams: Optional[int] = None
    condition_on_prev_tokens: Optional[bool] = None
    compression_ratio_threshold: Optional[float] = None
    temperature: Optional[Tuple[float, ...]] = None  # Optional tuple for temperature fallback
    logprob_threshold: Optional[float] = None
    no_speech_threshold: Optional[float] = None
    # Pipeline-level flag: True for segment timestamps, "word" for word-level.
    return_timestamps: Optional[Literal["word", True]] = None

    @field_validator("return_timestamps", mode="before")
    def cannonize_timestamps(cls, value):
        """Canonicalize the env string "true" (any case) to boolean True.

        BUG FIX: the value may already be a bool when set programmatically
        (the field's Literal allows True); calling ``.lower()`` on a bool
        raised AttributeError before. Only lowercase genuine strings.
        """
        if value is None:
            return None
        if isinstance(value, str) and value.lower() == "true":
            logging.info("return_timestamps == 'True'")
            return True
        return value

    model_config = {
        "env_prefix": "WHISPER_KWARGS_",
        "case_sensitive": False,
    }

    def to_kwargs(self):
        """Convert object attributes to a kwargs dict, excluding None values."""
        # Equivalent to the hand-rolled dict comprehension it replaces:
        # pydantic already supports dropping None values natively.
        return self.model_dump(exclude_none=True)
|