File size: 3,235 Bytes
3d36ead
9d67e70
3d36ead
880d01b
9d67e70
691565a
3d36ead
 
 
 
 
 
 
 
 
a49a698
 
 
3d36ead
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d67e70
e06292f
79b6e07
3d36ead
 
 
ef49102
e06292f
 
3d36ead
a443875
ef49102
3d36ead
 
9d67e70
 
 
 
 
 
 
 
 
 
 
 
691565a
d62e007
 
 
 
9f147d5
691565a
 
 
 
9d67e70
 
 
 
 
3d36ead
9d67e70
3d36ead
9d67e70
 
 
3d36ead
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from typing import Dict, List, Any, Literal, Optional, Tuple
import torch
import logging
from pydantic_settings import BaseSettings
from pydantic import field_validator

class EndpointHandler():
    """Inference Endpoint handler for Whisper speech-to-text.

    Loads ``openai/whisper-large-v3-turbo`` once at construction time and
    exposes ``__call__`` as the per-request entry point.
    """

    def __init__(self, path=""):
        # Prefer GPU with fp16; fall back to CPU with fp32.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        model_id = "openai/whisper-large-v3-turbo"

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype,
            low_cpu_mem_usage=True, use_safetensors=True,
            attn_implementation="sdpa"
        )
        model.to(device)

        processor = AutoProcessor.from_pretrained(model_id)

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe one request.

        Args:
            data: request payload. ``data["inputs"]`` holds the audio in any
                form the ASR pipeline accepts; if the key is absent the whole
                payload is used as the input (original fallback behavior).

        Returns:
            The pipeline prediction dict (``"text"`` plus ``"chunks"`` when
            timestamps were requested); serialized by the endpoint runtime.
        """
        # get inputs
        inputs = data.pop("inputs", data)

        # Decoding options come from WHISPER_KWARGS_* environment variables.
        whisper_parameter_handler = WhisperParameterHandler()
        # Compute the generate kwargs once; pydantic v2's `exclude` expects a
        # set, not a list. return_timestamps is a pipeline-level argument, so
        # it is excluded from generate_kwargs.
        generate_kwargs = whisper_parameter_handler.model_dump(
            exclude_none=True, exclude={"return_timestamps"}
        )
        logging.info(generate_kwargs)

        # run normal prediction
        prediction = self.pipe(
            inputs,
            return_timestamps=whisper_parameter_handler.return_timestamps,
            generate_kwargs=generate_kwargs,
        )
        logging.info(prediction)
        # 'chunks' only exists when timestamps were requested; indexing it
        # unconditionally raised KeyError for plain transcriptions.
        if "chunks" in prediction:
            logging.info(prediction["chunks"])
        return prediction
    

class WhisperParameterHandler(BaseSettings):
    """Whisper generation options sourced from ``WHISPER_KWARGS_*`` env vars.

    Every field defaults to None so that only explicitly-set variables are
    forwarded to the pipeline (callers dump with ``exclude_none=True`` or use
    :meth:`to_kwargs`).
    """

    language: Optional[str] = None  # Optional fields default to None
    max_new_tokens: Optional[int] = None
    num_beams: Optional[int] = None
    condition_on_prev_tokens: Optional[bool] = None
    compression_ratio_threshold: Optional[float] = None
    temperature: Optional[Tuple[float, ...]] = None  # Optional Tuple
    logprob_threshold: Optional[float] = None
    no_speech_threshold: Optional[float] = None
    return_timestamps: Optional[Literal["word", True]] = None

    @field_validator("return_timestamps", mode="before")
    def canonicalize_timestamps(cls, value):
        """Map raw input onto the ``Literal["word", True]`` field.

        Env vars arrive as strings, so "true" (any casing) must become the
        boolean True and "Word"/"WORD" must be lower-cased to pass Literal
        validation. Non-string input (e.g. an actual bool set in code) is
        passed through untouched instead of crashing on ``.lower()``.
        """
        if value is None or not isinstance(value, str):
            return value
        lowered = value.lower()
        if lowered == "true":
            logging.info("return_timestamps == 'True'")
            return True
        if lowered == "word":
            return "word"
        # Anything else is left for pydantic to reject with a clear error.
        return value

    model_config = {
        "env_prefix": "WHISPER_KWARGS_",
        "case_sensitive": False,
    }

    def to_kwargs(self):
        """Convert object attributes to kwargs dict, excluding None values."""
        # pydantic's own None-filtering replaces the manual comprehension.
        return self.model_dump(exclude_none=True)