File size: 5,038 Bytes
789b968 d1b9c9e 789b968 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import logging
from typing import Any, Dict, List, Optional
import numpy as np
import transformers
# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .modeling_ocismllama import OcisMllamaForConditionalGeneration
from .ocismllama_processing import OcisMllamaProcessor
class OcisMllamaPipeline(transformers.Pipeline):
    """Generation pipeline for ``OcisMllamaForConditionalGeneration``.

    Each input is a dict with optional keys ``turns`` (chat history),
    ``audio``, ``prompt``, ``images`` and ``sampling_rate``. The chat
    template is rendered, the multimodal processor builds the model inputs,
    ``model.generate`` runs, and the newly generated text is returned.
    """

    # Opt out of automatic component loading; ``__init__`` builds the
    # tokenizer and processors itself.
    _load_processor = False
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = False

    # Hub checkpoint used for the default tokenizer and image processor.
    _BASE_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    # Call-time kwargs that are forwarded to ``_forward`` / ``model.generate``.
    _GENERATION_KEYS = frozenset({"temperature", "max_new_tokens", "repetition_penalty"})
    # ``_forward`` rewrites every occurrence of the first id in ``input_ids``
    # to the second before generation.
    # NOTE(review): 128256 is presumably the Llama 3.2 Vision ``<|image|>``
    # token and 128004 a padding token — confirm why this swap is needed.
    _REMAPPED_TOKEN_ID = 128256
    _REMAPPED_TO_TOKEN_ID = 128004

    def __init__(
        self,
        model: OcisMllamaForConditionalGeneration,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        image_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs,
    ):
        """Build the pipeline, downloading any component not supplied.

        Args:
            model: The OcisMllama model used for generation.
            tokenizer: Text tokenizer. Defaults to
                ``AutoTokenizer.from_pretrained(_BASE_MODEL_ID)``.
            audio_processor: Audio processor. Defaults to
                ``AutoProcessor.from_pretrained(model.config.audio_config.audio_model_id)``.
            image_processor: Image processor. Defaults to
                ``AutoProcessor.from_pretrained(_BASE_MODEL_ID)``.
            **kwargs: Forwarded to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(self._BASE_MODEL_ID)
        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_config.audio_model_id
            )
        if image_processor is None:
            image_processor = transformers.AutoProcessor.from_pretrained(
                self._BASE_MODEL_ID
            )
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.processor = OcisMllamaProcessor(
            audio_processor=audio_processor,
            image_processor=image_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.audio_config.stack_factor,
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) params.

        Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty`` are
        recognised and forwarded to ``_forward``. Any other kwarg is ignored.
        """
        generation_kwargs = {
            key: value for key, value in kwargs.items() if key in self._GENERATION_KEYS
        }
        return {}, generation_kwargs, {}

    @staticmethod
    def _audio_to_float32(audio: Any) -> Any:
        """Return numpy audio as float32; other inputs are returned unchanged.

        float64 is cast directly. int16 and int32 PCM are scaled by
        2**15 and 2**31 respectively, mapping them into [-1, 1).
        Arrays of any other dtype are returned as-is.
        """
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                return audio.astype(np.float32)
            if audio.dtype == np.int16:
                return audio.astype(np.float32) / np.float32(32768.0)
            if audio.dtype == np.int32:
                return audio.astype(np.float32) / np.float32(2147483648.0)
        return audio

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn a request dict into model inputs.

        Args:
            inputs: Dict with optional keys ``turns`` (list of chat messages
                with ``role``/``content``), ``audio``, ``prompt`` (used only
                when audio is given; defaults to ``"<|audio|>"``), ``images``
                and ``sampling_rate`` (defaults to 16000).

        Returns:
            The ``OcisMllamaProcessor`` output for a batch of one text, with
            ``audio_values`` cast to the model dtype when present.
        """
        # Copy so that appending the audio turn below never mutates the
        # caller's chat history; treat an explicit ``None`` as no history.
        turns: list = list(inputs.get("turns") or [])
        audio = self._audio_to_float32(inputs.get("audio", None))

        # When audio is supplied without a pending user turn, synthesize one
        # from ``prompt`` and make sure it carries the audio placeholder.
        # NOTE(review): ``prompt`` is ignored when no audio is given, so
        # image-only requests must supply their text through ``turns``.
        if audio is not None and (not turns or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )
        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )
        output = self.processor(
            text=[text],
            audio=audio,
            images=inputs.get("images", None),
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)
        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.0,
    ) -> List[int]:
        """Run ``model.generate`` and return only the newly generated token ids.

        A falsy ``temperature`` (``None`` or ``0``) selects greedy decoding;
        any other value enables sampling at that temperature. Generation stops
        at the tokenizer's EOS token or, when the tokenizer has it as an added
        token, at ``<|eot_id|>``. ``model_inputs["input_ids"]`` is modified in
        place (see ``_REMAPPED_TOKEN_ID``).
        """
        temperature = temperature or None
        do_sample = temperature is not None

        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_ids = model_inputs["input_ids"]
        input_len = input_ids.shape[1]
        input_ids[input_ids == self._REMAPPED_TOKEN_ID] = self._REMAPPED_TO_TOKEN_ID

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        # Keep the first sequence only (preprocess builds a batch of one) and
        # drop the echoed prompt tokens.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids to text, skipping special tokens."""
        return self.tokenizer.decode(model_outputs, skip_special_tokens=True)
# Register the custom task at import time so that
# ``transformers.pipeline("ocismllama-pipeline", ...)`` builds an
# OcisMllamaPipeline, with ``transformers.AutoModel`` as the PyTorch model class.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    task="ocismllama-pipeline",
    type="multimodal",
    pipeline_class=OcisMllamaPipeline,
    pt_model=transformers.AutoModel,
)
|