import logging
from typing import Any, Dict, List, Optional

import transformers

# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor

logger = logging.getLogger(__name__)

# Call-time keyword arguments that are routed to ``_forward`` (generation options).
_GENERATION_KWARGS = ("temperature", "max_new_tokens", "repetition_penalty")


class UltravoxPipeline(transformers.Pipeline):
    """Pipeline that turns audio plus an optional prompt (or chat turns) into text.

    ``preprocess`` builds the chat-formatted text and runs the Ultravox
    processor, ``_forward`` calls ``model.generate``, and ``postprocess``
    decodes the generated tokens into a string.
    """

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs,
    ):
        """Create the pipeline, loading any missing tokenizer/audio processor.

        Args:
            model: The Ultravox model to run.
            tokenizer: Text tokenizer. When ``None`` it is loaded with
                ``AutoTokenizer.from_pretrained(model.config._name_or_path)``.
            audio_processor: Audio feature processor. When ``None`` it is
                loaded from ``model.config.audio_model_id``, falling back to
                ``model.config.audio_config._name_or_path``.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model.config._name_or_path
            )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id
                or model.config.audio_config._name_or_path
            )

        processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # Assign after the base-class init: recent transformers releases set
        # ``self.processor`` inside ``Pipeline.__init__`` (defaulting to None),
        # which would otherwise overwrite the Ultravox processor.
        self.processor = processor

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) params.

        Only the keys listed in ``_GENERATION_KWARGS`` are kept; they are
        passed to ``_forward``. Any other keyword argument is ignored.
        """
        generation_kwargs = {
            key: kwargs[key] for key in _GENERATION_KWARGS if key in kwargs
        }
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Build model inputs from a request dict.

        Args:
            inputs: Must contain ``"audio"``. May also contain ``"turns"``
                (a chat message list passed to the chat template as-is),
                ``"prompt"`` (used only when ``"turns"`` is absent; defaults to
                ``"<|audio|>"``, and ``" <|audio|>"`` is appended if the prompt
                lacks that placeholder) and ``"sampling_rate"`` (defaults to
                16000).

        Returns:
            The processor output, with ``audio_values`` (when present) cast to
            the model's dtype.

        Raises:
            ValueError: If ``inputs`` has no ``"audio"`` entry.
        """
        # TODO: allow text-only mode?
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if "audio" not in inputs:
            raise ValueError("Audio input is required")

        if "turns" in inputs:
            turns = inputs["turns"]
        else:
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logger.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns = [{"role": "user", "content": prompt}]

        text = self.processor.tokenizer.apply_chat_template(turns, tokenize=False)

        if "sampling_rate" not in inputs:
            logger.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        output = self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            # Match the model's dtype (e.g. when the model is loaded in half precision).
            output["audio_values"] = output["audio_values"].to(self.model.dtype)

        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> Any:
        """Run ``model.generate`` and return the newly generated token ids.

        Args:
            model_inputs: Output of ``preprocess``; must include ``input_ids``.
            temperature: Sampling temperature. ``None`` or ``0`` selects greedy
                decoding; any other value enables sampling.
            max_new_tokens: Passed through to ``generate``.
            repetition_penalty: Passed through to ``generate``.

        Returns:
            Token ids of the first generated sequence, with the first
            ``input_ids.shape[1]`` positions (the echoed prompt) removed.
        """
        # A falsy temperature (0 / None) means "no sampling".
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on the regular EOS token, and on "<|eot_id|>" when the
        # tokenizer defines it as an added token.
        terminators: List[int] = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode the generated token ids into text, skipping special tokens."""
        return self.tokenizer.decode(model_outputs, skip_special_tokens=True)


transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)