# my-custom-llama3.1-8b-try2 / custom_llama_config.py
# Only the base LlamaConfig is needed for this custom configuration class.
from transformers.models.llama.configuration_llama import LlamaConfig
class CustomLlamaConfig(LlamaConfig):
    """LlamaConfig extended with quantization and layer-split settings."""

    def __init__(self, quant=None, split_idx=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Defaults keep the config instantiable with no arguments, which
        # transformers relies on when diffing a config against its defaults.
        self.quant = quant
        self.split_idx = split_idx

    def to_dict(self):
        # PretrainedConfig.to_dict() already serializes instance attributes;
        # this override keeps the custom fields explicit in the saved config.
        config_dict = super().to_dict()
        config_dict["quant"] = self.quant
        config_dict["split_idx"] = self.split_idx
        return config_dict
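

# Minimal usage sketch. The values for quant and split_idx and the save
# directory are illustrative assumptions, not taken from the original file:
# build the config, round-trip it through save_pretrained/from_pretrained,
# and confirm the custom fields survive serialization.
if __name__ == "__main__":
    config = CustomLlamaConfig(quant="int8", split_idx=16)
    config.save_pretrained("./my-custom-llama-config")
    reloaded = CustomLlamaConfig.from_pretrained("./my-custom-llama-config")
    assert reloaded.quant == "int8" and reloaded.split_idx == 16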