from types import MethodType
from functools import partial
import self_extend_patch as SE
def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
"""
    Recursively replace a method on every sub-module of `instance` whose class
    name matches `target_class_name`.

    This helper was partly generated with ChatGPT. It is only used here to swap
    out the attention forward method of a model and has not been tested for
    other purposes.

    instance:
        model instance to modify.
    target_class_name:
        name of the attention class to modify, e.g. 'LlamaAttention',
        'GPTNeoXAttention', etc.
    target_method_name:
        name of the method to replace, e.g. 'forward'.
    new_method:
        new method to replace the original one, e.g. 'self_extend_forward'.
        Its first parameter must be 'self' so it can be bound to the instance.
    """
target_found = False
if visited_instances is None:
visited_instances = set()
# Unique identifier for the instance (using id() since object's id is unique)
instance_id = id(instance)
if instance_id in visited_instances:
target_found = False
return target_found
    # Mark this instance as visited
visited_instances.add(instance_id)
# Check if this instance is of the target class
if instance.__class__.__name__ == target_class_name:
bond_method = MethodType(new_method, instance)
setattr(instance, target_method_name, bond_method)
target_found = True
return target_found
elif hasattr(instance, '__dict__'):
for attr_name, attr_value in instance.__dict__.items():
if isinstance(attr_value, object) and not isinstance(attr_value, (list, tuple, dict, set)):
_found = modify_method_of_instance(attr_value, target_class_name, target_method_name, new_method, visited_instances)
if _found:
target_found = True
elif isinstance(attr_value, (list, tuple)):
for item in attr_value:
if isinstance(item, object):
_found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
if _found:
target_found = True
            # If the attribute value is a dictionary, iterate over its values and recurse.
            # E.g. a ModuleList stores its modules in a dictionary: ._modules
elif isinstance(attr_value, dict):
for key, value in attr_value.items():
if isinstance(value, object):
_found = modify_method_of_instance(value, target_class_name, target_method_name, new_method, visited_instances)
if _found:
target_found = True
# If attribute value is a set, iterate and recurse
elif isinstance(attr_value, set):
for item in attr_value:
if isinstance(item, object):
_found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
if _found:
target_found = True
return target_found
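

# An informal sketch of how the helper above behaves, using toy classes rather
# than a real transformers model; the names below are illustrative assumptions,
# not part of the SelfExtend code:
#
#     class ToyAttention:
#         def forward(self, x):
#             return x
#
#     class ToyModel:
#         def __init__(self):
#             self.attn = ToyAttention()
#
#     def doubled_forward(self, x):
#         return 2 * x
#
#     model = ToyModel()
#     modify_method_of_instance(model, "ToyAttention", "forward", doubled_forward)
#     model.attn.forward(3)  # -> 6, the new method is bound to the ToyAttention instance
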
def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
    '''
    loaded_model:
        model to apply the SelfExtend attention modification to.
    group_size:
        group size for the grouped (self-extended) part of the attention.
    window_size:
        window size for the neighbor (standard) part of the attention.
    enable_flash_attention:
        whether to patch the FlashAttention2 attention classes instead of the
        vanilla ones.
    flash_attention_impl:
        flash attention implementation to use, either 'flash_attn' or 'triton'.
    scale_base:
        base for the attention scale, equal to the pretraining length,
        e.g. 4096 for Llama, 8192 for Gemma.
        Two recommended scaling choices:
            yarn: https://arxiv.org/abs/2309.00071
            log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
        This is helpful when retrieving a long sequence (e.g. a long passkey),
        but on real-world data (e.g. LongBench, LEval) the impact is minor.
        The results reported in our paper do not use this scaling, except for
        long passkey retrieval.
    '''
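    # Informal sketch of the "log" scaling option (the exact expression lives in
    # self_extend_patch and may differ in detail): with scale_base set to the
    # pretraining length L, a query at position n is scaled by roughly
    #
    #     max(1, log(n) / log(L))
    #
    # so positions beyond the pretraining length get a mild, entropy-motivated
    # boost, which mainly matters for long passkey retrieval.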
arch_name = loaded_model.__class__.__name__
if 'Llama' in arch_name:
if enable_flash_attention:
if flash_attention_impl == "flash_attn":
self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
                modified_1 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
                modified_2 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using flash_attn flash self_extend!!")
                if (not modified_1) or (not modified_2):
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
elif flash_attention_impl == "triton":
self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward_triton,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
                modified = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using triton flash self_extend!!")
                if not modified:
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
else:
                raise Exception("flash_attention_impl must be set to 'flash_attn' or 'triton'.")
else:
self_extend_attention_forward = partial(SE.Llama.self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            # Note: since transformers 4.36 the default attention class is LlamaSdpaAttention,
            # but before 4.36 (and in 4.38) it is LlamaAttention.
            modified_2 = modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
elif 'Mistral' in arch_name:
        # Mistral shares the same architecture as Llama, so the implementations should be interchangeable.
if enable_flash_attention:
self_extend_attention_forward = partial(SE.Mistral.flash_self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
else:
self_extend_attention_forward = partial(SE.Mistral.self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "MistralAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
elif 'Gemma' in arch_name:
if enable_flash_attention:
self_extend_attention_forward = partial(SE.Gemma.flash_self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
else:
self_extend_attention_forward = partial(SE.Gemma.self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "GemmaAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
elif 'Qwen2' in arch_name:
if enable_flash_attention:
self_extend_attention_forward = partial(SE.Qwen2.flash_self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
else:
self_extend_attention_forward = partial(SE.Qwen2.self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "Qwen2Attention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
elif 'Phi' in arch_name:
if enable_flash_attention:
self_extend_attention_forward = partial(SE.Phi.flash_self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
else:
self_extend_attention_forward = partial(SE.Phi.self_extend_forward,
group_size_1=group_size,
group_size_2=window_size,
scale_base=scale_base)
            modified_2 = modify_method_of_instance(loaded_model, "PhiAttention", "forward", self_extend_attention_forward)
            if not modified_2:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
else:
        raise NotImplementedError(f"SelfExtend does not support this model architecture: {arch_name}")
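

# The block below is an informal usage sketch rather than part of the original
# module: the checkpoint name and the group/window values are illustrative
# assumptions, and running it requires the corresponding model weights.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    # Any decoder-only model whose architecture is handled above will do;
    # "meta-llama/Llama-2-7b-hf" is used purely as an example name here.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        # The non-flash path patches 'LlamaAttention', i.e. the eager class
        # (see the note inside apply()), so request eager attention explicitly.
        attn_implementation="eager",
    )

    # Grouped attention with group_size, plus a neighbor window of window_size
    # tokens of standard attention; see the apply() docstring for scale_base.
    apply(model, group_size=16, window_size=1024)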