from types import MethodType
from functools import partial
import self_extend_patch as SE
def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
    """
    Replace a method on every sub-module of a model instance whose class name matches target_class_name.
    Parts of this function were generated with ChatGPT.
    It recursively walks the instance's attributes and rebinds the target method wherever it finds a match.
    Currently we only use this function to patch the attention method of a model; it has not been tested for other uses.

    instance:
        instance of the model to modify.
    target_class_name:
        name of the attention class to modify, e.g. 'LlamaAttention', 'GPTNeoXAttention', etc.
    target_method_name:
        name of the method to replace, e.g. 'forward'.
    new_method:
        new method to replace the original one, e.g. 'self_extend_forward'.
        Its first parameter must be 'self', which will be bound to the instance.
    """
    target_found = False
    if visited_instances is None:
        visited_instances = set()
    # Unique identifier for the instance (id() is unique for a live object)
    instance_id = id(instance)
    if instance_id in visited_instances:
        return target_found
    # Mark the instance as visited
    visited_instances.add(instance_id)
    # Check if this instance is of the target class
    if instance.__class__.__name__ == target_class_name:
        bound_method = MethodType(new_method, instance)
        setattr(instance, target_method_name, bound_method)
        target_found = True
        return target_found
    elif hasattr(instance, '__dict__'):
        for attr_name, attr_value in instance.__dict__.items():
            if isinstance(attr_value, object) and not isinstance(attr_value, (list, tuple, dict, set)):
                _found = modify_method_of_instance(attr_value, target_class_name, target_method_name, new_method, visited_instances)
                if _found:
                    target_found = True
            # If the attribute is a list or tuple, iterate over its items and recurse
            elif isinstance(attr_value, (list, tuple)):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute is a dictionary, iterate over its values and recurse.
            # E.g. for a ModuleList, its modules are stored in a dictionary: ._modules
            elif isinstance(attr_value, dict):
                for key, value in attr_value.items():
                    if isinstance(value, object):
                        _found = modify_method_of_instance(value, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True
            # If the attribute is a set, iterate over its items and recurse
            elif isinstance(attr_value, set):
                for item in attr_value:
                    if isinstance(item, object):
                        _found = modify_method_of_instance(item, target_class_name, target_method_name, new_method, visited_instances)
                        if _found:
                            target_found = True

    return target_found
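

# A minimal, illustrative sketch of how modify_method_of_instance rebinds a method on a
# nested attribute. The classes `_Inner`/`_Outer` and the function `_patched_forward`
# are hypothetical names used only for this demonstration and are not part of the
# SelfExtend API. The helper is defined here but never called at import time.
def _demo_modify_method_of_instance():
    class _Inner:
        def forward(self):
            return "original"

    class _Outer:
        def __init__(self):
            self.child = _Inner()

    def _patched_forward(self):
        return "patched"

    model = _Outer()
    # Walks model's attributes, finds the _Inner instance, and rebinds its 'forward'.
    found = modify_method_of_instance(model, "_Inner", "forward", _patched_forward)
    assert found and model.child.forward() == "patched"
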
def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
    '''
    loaded_model:
        model to apply the SelfExtend attention patch to.
    group_size:
        group size for the grouped part of self-attention.
    window_size:
        window size for the neighbor part of self-attention.
    enable_flash_attention:
        whether to patch the FlashAttention2 variant of the attention class.
    flash_attention_impl:
        flash-attention implementation to use, 'flash_attn' or 'triton' (currently only distinguished for Llama).
    scale_base:
        base for the attention scale, equal to the pretraining length,
        e.g. 4096 for Llama, 8192 for Gemma.
        Two recommended scale factors:
            yarn: https://arxiv.org/abs/2309.00071
            log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
        This is helpful when retrieving from a long sequence (e.g. a long passkey),
        but on real-world data (e.g. LongBench, LEval) the impact is minor.
        The results reported in our paper do not use this scale, except for long passkey retrieval.
    '''
    arch_name = loaded_model.__class__.__name__
    if 'Llama' in arch_name:
        if enable_flash_attention:
            if flash_attention_impl == "flash_attn":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward,
                                                        group_size_1=group_size,
                                                        group_size_2=window_size,
                                                        scale_base=scale_base)
                modified_1 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
                modified_2 = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using flash_attn flash self_extend!!")
                if (not modified_1) or (not modified_2):
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
            elif flash_attention_impl == "triton":
                self_extend_attention_forward = partial(SE.Llama.flash_self_extend_forward_triton,
                                                        group_size_1=group_size,
                                                        group_size_2=window_size,
                                                        scale_base=scale_base)
                modified = modify_method_of_instance(loaded_model, "LlamaFlashAttention2", "forward", self_extend_attention_forward)
                print("Using triton flash self_extend!!")
                if not modified:
                    raise Exception(f"Failed to modify the attention method of {arch_name}")
            else:
                raise Exception("flash_attention_impl must be set to 'flash_attn' or 'triton'.")
        else:
            self_extend_attention_forward = partial(SE.Llama.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            # Note: in transformers 4.36 the default attention class is LlamaSdpaAttention,
            # but before 4.36 and in 4.38 it is LlamaAttention.
            modified = modify_method_of_instance(loaded_model, "LlamaAttention", "forward", self_extend_attention_forward)
            if not modified:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Mistral' in arch_name:
        # Mistral shares the same architecture as Llama, so the implementations should be interchangeable.
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Mistral.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "MistralFlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Mistral.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified = modify_method_of_instance(loaded_model, "MistralAttention", "forward", self_extend_attention_forward)
            if not modified:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Gemma' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Gemma.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "GemmaFlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Gemma.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified = modify_method_of_instance(loaded_model, "GemmaAttention", "forward", self_extend_attention_forward)
            if not modified:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Qwen2' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Qwen2.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "Qwen2FlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Qwen2.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified = modify_method_of_instance(loaded_model, "Qwen2Attention", "forward", self_extend_attention_forward)
            if not modified:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    elif 'Phi' in arch_name:
        if enable_flash_attention:
            self_extend_attention_forward = partial(SE.Phi.flash_self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified_1 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "_flash_attention_forward", SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
            modified_2 = modify_method_of_instance(loaded_model, "PhiFlashAttention2", "forward", self_extend_attention_forward)
            if (not modified_1) or (not modified_2):
                raise Exception(f"Failed to modify the attention method of {arch_name}")
        else:
            self_extend_attention_forward = partial(SE.Phi.self_extend_forward,
                                                    group_size_1=group_size,
                                                    group_size_2=window_size,
                                                    scale_base=scale_base)
            modified = modify_method_of_instance(loaded_model, "PhiAttention", "forward", self_extend_attention_forward)
            if not modified:
                raise Exception(f"Failed to modify the attention method of {arch_name}")
    else:
        raise NotImplementedError(f"SelfExtend does not support this architecture yet: {arch_name}")
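

# Usage sketch (illustrative): how this patch is typically applied to a model loaded with
# Hugging Face transformers. The checkpoint name, group_size, and window_size below are
# example values chosen for the sketch, not values required by this file.
def _example_apply_self_extend():
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",  # any supported Llama/Mistral/Gemma/Qwen2/Phi checkpoint
        torch_dtype="auto",
    )
    # group_size and window_size control the grouped and neighbor attention spans;
    # scale_base=-1 matches the default and leaves the optional scaling from
    # apply()'s docstring unused.
    apply(model, group_size=8, window_size=1024, enable_flash_attention=False, scale_base=-1)
    return model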