File size: 1,059 Bytes
6a87ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers.models.llama.modeling_llama import * #LLaMAModel
from typing import List, Optional, Tuple, Union
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_greater_or_equal_2_10,
    is_torchdynamo_compiling,
    logging,
    replace_return_docstrings,
)
from transformers.models.llama.configuration_llama import LlamaConfig

class CustomLlamaConfig(LlamaConfig):
    """LlamaConfig extended with quantization and model-split settings.

    Bug fix: this was declared with ``def`` instead of ``class``, making it a
    plain function — the nested ``__init__``/``to_dict`` were unreachable as
    methods and the zero-argument ``super()`` calls would fail at runtime.

    Attributes:
        quant: quantization setting carried alongside the base config
            (opaque here; semantics defined by the consuming model code).
        split_idx: index at which the model is split — presumably the layer
            boundary for partitioned inference; TODO confirm with callers.
    """

    def __init__(self, quant, split_idx, *args, **kwargs):
        """Build a LlamaConfig and attach the two custom fields.

        Args:
            quant: quantization setting to store on the config.
            split_idx: split index to store on the config.
            *args, **kwargs: forwarded unchanged to ``LlamaConfig.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.quant = quant
        self.split_idx = split_idx

    def to_dict(self):
        """Serialize the config, including the custom fields.

        Returns:
            dict: the base-class serialization with ``"quant"`` and
            ``"split_idx"`` keys added, so the extras survive save/load.
        """
        config_dict = super().to_dict()
        config_dict["quant"] = self.quant
        config_dict["split_idx"] = self.split_idx
        return config_dict