sergeipetrov commited on
Commit
029e0bd
1 Parent(s): 632d04c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -22
handler.py CHANGED
@@ -6,66 +6,56 @@ from pyannote.audio import Pipeline
6
  from transformers import pipeline, AutoModelForCausalLM
7
  from diarization_utils import diarize
8
  from huggingface_hub import HfApi
9
- from transformers.pipelines.audio_utils import ffmpeg_read
10
- from pydantic import Json, BaseModel, ValidationError
11
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
- class InferenceConfig(BaseModel):
17
- task: Literal["transcribe", "translate"] = "transcribe"
18
- batch_size: int = 24
19
- assisted: bool = False
20
- chunk_length_s: int = 30
21
- sampling_rate: int = 16000
22
- language: Optional[str] = None
23
- num_speakers: Optional[int] = None
24
- min_speakers: Optional[int] = None
25
- max_speakers: Optional[int] = None
26
-
27
-
28
  class EndpointHandler():
29
- def __init__(self):
30
 
31
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
32
  logger.info(f"Using device: {device.type}")
33
  torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
34
 
35
  self.assistant_model = AutoModelForCausalLM.from_pretrained(
36
- os.getenv("ASSISTANT_MODEL"),
37
  torch_dtype=torch_dtype,
38
  low_cpu_mem_usage=True,
39
  use_safetensors=True
40
- ) if os.getenv("ASSISTANT_MODEL") else None
41
 
42
  if self.assistant_model:
43
  self.assistant_model.to(device)
44
 
45
  self.asr_pipeline = pipeline(
46
  "automatic-speech-recognition",
47
- model=os.getenv("ASR_MODEL"),
48
  torch_dtype=torch_dtype,
49
  device=device
50
  )
51
 
52
- if os.getenv("DIARIZATION_MODEL"):
53
  # diarization pipeline doesn't raise if there is no token
54
  HfApi().whoami(model_settings.hf_token)
55
  self.diarization_pipeline = Pipeline.from_pretrained(
56
- checkpoint_path=os.getenv("DIARIZATION_MODEL"),
57
- use_auth_token=os.getenv("HF_TOKEN"),
58
  )
59
  self.diarization_pipeline.to(device)
60
  else:
61
  self.diarization_pipeline = None
 
62
 
63
  async def __call__(self, file, parameters):
64
  try:
65
  parameters = InferenceConfig(**parameters)
66
  except ValidationError as e:
67
  logger.error(f"Error validating parameters: {e}")
68
- raise ValidationError(f"Error validating parameters: {e}")
69
 
70
  logger.info(f"inference parameters: {parameters}")
71
 
 
6
  from transformers import pipeline, AutoModelForCausalLM
7
  from diarization_utils import diarize
8
  from huggingface_hub import HfApi
9
+ from pydantic import ValidationError
10
+ from starlette.exceptions import HTTPException
11
 
12
+ from config import model_settings, InferenceConfig
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class EndpointHandler():
18
+ def __init__(self, model_settings):
19
 
20
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
21
  logger.info(f"Using device: {device.type}")
22
  torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
23
 
24
  self.assistant_model = AutoModelForCausalLM.from_pretrained(
25
+ model_settings.assistant_model,
26
  torch_dtype=torch_dtype,
27
  low_cpu_mem_usage=True,
28
  use_safetensors=True
29
+ ) if model_settings.assistant_model else None
30
 
31
  if self.assistant_model:
32
  self.assistant_model.to(device)
33
 
34
  self.asr_pipeline = pipeline(
35
  "automatic-speech-recognition",
36
+ model=model_settings.asr_model,
37
  torch_dtype=torch_dtype,
38
  device=device
39
  )
40
 
41
+ if model_settings.diarization_model:
42
  # diarization pipeline doesn't raise if there is no token
43
  HfApi().whoami(model_settings.hf_token)
44
  self.diarization_pipeline = Pipeline.from_pretrained(
45
+ checkpoint_path=model_settings.diarization_model,
46
+ use_auth_token=model_settings.hf_token,
47
  )
48
  self.diarization_pipeline.to(device)
49
  else:
50
  self.diarization_pipeline = None
51
+
52
 
53
  async def __call__(self, file, parameters):
54
  try:
55
  parameters = InferenceConfig(**parameters)
56
  except ValidationError as e:
57
  logger.error(f"Error validating parameters: {e}")
58
+ raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}")
59
 
60
  logger.info(f"inference parameters: {parameters}")
61