whisper-large_v2_test / pipe_handler.py
from typing import Any, Dict, List

import torch
import soundfile as sf
from transformers import (
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    pipeline,
)

class EndpointHandler:
    def __init__(self, path=""):
        # Load the fine-tuned checkpoint from the repository path and reuse the
        # base whisper-large tokenizer/feature extractor configured for Korean.
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large", language="korean", task="transcribe")
        feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large")
        model = WhisperForConditionalGeneration.from_pretrained(path)

        # Run on GPU when available (device=0), otherwise on CPU (device=-1).
        device = 0 if torch.cuda.is_available() else -1

        self.pipe = pipeline(
            task="automatic-speech-recognition",
            model=model,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            chunk_length_s=30,
            device=device,
        )
        # Always transcribe to Korean instead of relying on language detection.
        self.pipe.model.config.forced_decoder_ids = tokenizer.get_decoder_prompt_ids(
            language="ko", task="transcribe"
        )
    def __call__(self, data: Any) -> List[Dict[str, str]]:
        print('==========NEW PROCESS=========')
        # `data["inputs"]` carries the raw audio payload (bytes or a file path);
        # the ASR pipeline handles decoding, chunking, and feature extraction.
        # The manual route (feature extractor -> generate -> batch_decode) that was
        # previously drafted here in comments is kept as `transcribe_manually` below.
        result = self.pipe(data["inputs"])
        return result
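
# The commented-out experiments that originally followed the handler decoded the
# audio by hand instead of going through the pipeline. A minimal sketch of that
# route is kept below for reference; it assumes 16 kHz mono audio, reuses the
# same checkpoints as the handler, and is not called by the endpoint itself.
def transcribe_manually(path, audio_array, sampling_rate=16000):
    """Explicit feature-extract -> generate -> decode loop for a single waveform."""
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large", language="korean", task="transcribe")
    feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large")
    model = WhisperForConditionalGeneration.from_pretrained(path)

    # Log-mel input features for one padded 30 s window, shape (1, 80, 3000).
    input_features = feature_extractor(
        audio_array, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features

    # Force Korean transcription, then decode the generated ids back to text.
    forced_ids = tokenizer.get_decoder_prompt_ids(language="ko", task="transcribe")
    generated_ids = model.generate(input_features, forced_decoder_ids=forced_ids)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Local smoke test; "sample.wav" is a placeholder for a 16 kHz mono recording and
# is not part of this repository.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    print(handler({"inputs": "sample.wav"}))

    audio, rate = sf.read("sample.wav")
    print(transcribe_manually(".", audio, sampling_rate=rate))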