|
from transformers.models.llama.modeling_llama import * |
|
from typing import List, Optional, Tuple, Union |
|
from transformers.cache_utils import Cache, DynamicCache, StaticCache |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
QuestionAnsweringModelOutput, |
|
SequenceClassifierOutputWithPast, |
|
TokenClassifierOutput, |
|
) |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_flash_attn_greater_or_equal_2_10, |
|
is_torchdynamo_compiling, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
from transformers.models.llama.configuration_llama import LlamaConfig |
|
|
|
class CustomLlamaConfig(LlamaConfig):
    """LlamaConfig extended with quantization / model-split metadata.

    Adds two extra fields on top of the stock Llama configuration:
        quant:     quantization setting for the model (opaque to this class;
                   presumably a mode string or bit-width — confirm with callers).
        split_idx: layer index at which the model is split (e.g. for
                   pipeline/device partitioning — confirm with callers).

    NOTE: the original code declared this with ``def`` instead of ``class``,
    which made ``CustomLlamaConfig`` a plain function and broke instantiation
    entirely; fixed here.
    """

    def __init__(self, quant=None, split_idx=None, *args, **kwargs):
        """Initialize the config.

        Args:
            quant: quantization setting; defaults to None so the config can be
                reconstructed from kwargs alone (required by
                ``PretrainedConfig.from_pretrained`` round-trips).
            split_idx: model split index; defaults to None for the same reason.
            *args, **kwargs: forwarded to ``LlamaConfig.__init__``.
        """
        # Initialize the base LlamaConfig first so our custom attributes are
        # not clobbered by the parent's attribute setup.
        super().__init__(*args, **kwargs)
        self.quant = quant
        self.split_idx = split_idx

    def to_dict(self):
        """Serialize the config, including the custom fields.

        Returns:
            dict: the base config dict augmented with ``quant`` and
            ``split_idx`` so they survive ``save_pretrained`` /
            ``from_pretrained`` round-trips.
        """
        config_dict = super().to_dict()
        config_dict["quant"] = self.quant
        config_dict["split_idx"] = self.split_idx
        return config_dict