File size: 5,038 Bytes
789b968
 
 
 
 
 
 
 
d1b9c9e
 
789b968
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import transformers

# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .modeling_ocismllama import OcisMllamaForConditionalGeneration
from .ocismllama_processing import OcisMllamaProcessor


class OcisMllamaPipeline(transformers.Pipeline):
    """Custom pipeline that feeds chat turns, audio and images to an OcisMllama model.

    The pipeline builds an ``OcisMllamaProcessor`` from a tokenizer, an audio
    processor and an image processor, renders the conversation with the
    tokenizer's chat template, runs ``model.generate`` and decodes the newly
    generated tokens to a string.
    """

    # NOTE(review): these flags presumably tell ``transformers.pipeline()`` not
    # to auto-load the corresponding components, since ``__init__`` builds them
    # itself — confirm against the installed transformers version.
    _load_processor = False
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = False

    def __init__(
        self,
        model: OcisMllamaForConditionalGeneration,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        image_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Create the pipeline and its combined processor.

        Args:
            model: The model to run; its ``config.audio_config`` supplies
                ``audio_model_id`` and ``stack_factor``.
            tokenizer: Tokenizer to use. When ``None`` it is loaded with
                ``AutoTokenizer`` from
                ``'meta-llama/Llama-3.2-11B-Vision-Instruct'``.
            audio_processor: Audio processor to use. When ``None`` it is
                loaded with ``AutoProcessor`` from
                ``model.config.audio_config.audio_model_id``.
            image_processor: Image processor to use. When ``None`` it is
                loaded with ``AutoProcessor`` from
                ``'meta-llama/Llama-3.2-11B-Vision-Instruct'``.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                'meta-llama/Llama-3.2-11B-Vision-Instruct'
            )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_config.audio_model_id
            )

        if image_processor is None:
            image_processor = transformers.AutoProcessor.from_pretrained(
                'meta-llama/Llama-3.2-11B-Vision-Instruct'
            )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # Single processor object used by ``preprocess`` for text, audio and images.
        self.processor = OcisMllamaProcessor(
            audio_processor=audio_processor,
            image_processor=image_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.audio_config.stack_factor,
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        Only ``temperature``, ``max_new_tokens`` and ``repetition_penalty`` are
        kept, and they are all routed to ``_forward``; every other keyword is
        dropped.
        """
        generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
        generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Turn one request dict into model inputs.

        Args:
            inputs: Request dict. Keys read here:
                ``"turns"`` (list of chat messages, default empty),
                ``"audio"`` (optional; NumPy arrays of dtype float64, int16 or
                int32 are converted to float32),
                ``"prompt"`` (user text, default ``"<|audio|>"``),
                ``"sampling_rate"`` (default 16000) and
                ``"images"`` (optional).

        Returns:
            The output of ``self.processor``; when it contains
            ``"audio_values"`` that entry is cast to the model's dtype.
        """
        # Work on a shallow copy so the caller's conversation history is never
        # mutated by the ``append`` below; an explicit ``None`` means "no history".
        turns: list = list(inputs.get("turns") or [])

        audio = inputs.get("audio", None)
        # Convert to float32 if needed.
        if isinstance(audio, np.ndarray):
            if audio.dtype == np.float64:
                audio = audio.astype(np.float32)
            elif audio.dtype == np.int16:
                # Scale 16-bit PCM by 2**15 so samples land in [-1, 1).
                audio = audio.astype(np.float32) / np.float32(32768.0)
            elif audio.dtype == np.int32:
                # Scale 32-bit PCM by 2**31 so samples land in [-1, 1).
                audio = audio.astype(np.float32) / np.float32(2147483648.0)

        # Audio needs a user turn to live in: add one unless the conversation
        # already ends with a user message.
        if audio is not None and (not turns or turns[-1]["role"] != "user"):
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )

                prompt += " <|audio|>"
            turns.append({"role": "user", "content": prompt})

        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs and audio is not None:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        images = inputs.get("images", None)
        output = self.processor(
            text=[text],
            audio=audio,
            images=images,
            sampling_rate=inputs.get("sampling_rate", 16000),
        )
        if "audio_values" in output:
            output["audio_values"] = output["audio_values"].to(self.model.dtype)
        return output

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.0,
    ) -> List[int]:
        """Run generation and return only the newly generated token ids.

        Args:
            model_inputs: Output of ``preprocess``; must contain ``"input_ids"``.
            temperature: Sampling temperature. ``None`` or ``0`` disables
                sampling (``do_sample=False``).
            max_new_tokens: Passed through to ``model.generate``.
            repetition_penalty: Passed through to ``model.generate``.

        Returns:
            ``outputs[0]`` with the first ``input_len`` (prompt) positions
            removed. NOTE(review): annotated as ``List[int]`` but this is a
            slice of whatever ``generate`` returns — presumably a tensor.
        """
        # A falsy temperature (None or 0) is normalised to None => greedy decoding.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on EOS and, when the tokenizer defines it, on "<|eot_id|>" as well.
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]
        # Rewrite every occurrence of id 128256 to id 128004, in place.
        # NOTE(review): the meaning of these two hard-coded ids is not visible
        # in this file — presumably a placeholder token is swapped for one the
        # model can embed; confirm against the tokenizer vocabulary.
        model_inputs['input_ids'][model_inputs['input_ids']==128256] = 128004
        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode the generated token ids to text, skipping special tokens."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text


# Register the pipeline class under the "ocismllama-pipeline" task name, with
# AutoModel as its PyTorch model class and "multimodal" as its type.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    task="ocismllama-pipeline",
    pipeline_class=OcisMllamaPipeline,
    pt_model=transformers.AutoModel,
    type="multimodal",
)