File size: 12,381 Bytes
7a58a7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
from types import MethodType
from functools import partial
import self_extend_patch as SE
def modify_method_of_instance(instance, target_class_name, target_method_name, new_method, visited_instances=None):
    """
    Replace a method on every object of a given class reachable from ``instance``.

    The object graph is walked recursively through instance ``__dict__`` attributes and
    through the elements of lists, tuples, sets and frozensets and the values of dicts,
    at any nesting depth. This covers, e.g., the ``_modules`` dict of a ModuleList or a
    dict whose value is a list. Each object is visited at most once, so reference cycles
    terminate.

    When an object's class name equals ``target_class_name``:
      - ``new_method`` is bound to that object with ``types.MethodType``;
      - it is set as the instance attribute ``target_method_name`` (the class itself is
        not modified);
      - that object's own attributes are not searched further.

    Originally drafted with the help of ChatGPT. Currently it is only used to patch the
    attention modules of a model.

    instance:
        root object to search, e.g. a loaded model.
    target_class_name:
        exact class name to match, e.g. 'LlamaAttention', 'GPTNeoXAttention'.
        Subclasses with a different name do not match.
    target_method_name:
        name of the method to replace, e.g. 'forward'.
    new_method:
        callable to install, e.g. a ``functools.partial`` of a self-extend forward.
        It must accept the instance as its first positional argument (``self``),
        because it is bound to the instance.
    visited_instances:
        internal bookkeeping: set of ``id()`` values already visited. Leave as None.

    Returns:
        True if at least one matching instance was patched, False otherwise.
    """
    if visited_instances is None:
        visited_instances = set()

    # id() is unique among simultaneously alive objects; everything traversed here is
    # referenced from the graph being walked (or from the `children` snapshot below).
    instance_id = id(instance)
    if instance_id in visited_instances:
        return False
    visited_instances.add(instance_id)

    if instance.__class__.__name__ == target_class_name:
        # Bind so that calling instance.<target_method_name>(...) passes the instance as `self`.
        setattr(instance, target_method_name, MethodType(new_method, instance))
        return True

    # Plain containers have no __dict__, so their elements are collected explicitly;
    # this also makes containers nested inside containers reachable.
    # Snapshot into a list so patching during recursion never mutates what we iterate.
    children = []
    if isinstance(instance, dict):
        children.extend(instance.values())
    elif isinstance(instance, (list, tuple, set, frozenset)):
        children.extend(instance)
    if hasattr(instance, '__dict__'):
        children.extend(instance.__dict__.values())

    target_found = False
    for child in children:
        # No short-circuit: every reachable match must be patched, not just the first one.
        if modify_method_of_instance(child, target_class_name, target_method_name, new_method, visited_instances):
            target_found = True
    return target_found
def apply(loaded_model, group_size, window_size, enable_flash_attention=False, scale_base=-1, flash_attention_impl="triton"):
    """
    Patch the attention modules of ``loaded_model`` in place with the SelfExtend forward pass.

    The model family is chosen by substring match on the model's class name. Families are
    checked in the order Llama, Mistral, Gemma, Qwen2, Phi, and the first match wins; so,
    e.g., any class name containing 'Phi' takes the Phi path. Every attention instance of
    the family's attention class then has its method(s) replaced via
    ``modify_method_of_instance``.

    loaded_model:
        model to apply the self-attention extension to.
    group_size:
        group size for the self-attention extension (passed to the SE forward as
        ``group_size_1``).
    window_size:
        window size for the self-attention extension (passed to the SE forward as
        ``group_size_2``).
    enable_flash_attention:
        if True, patch ``<Family>FlashAttention2`` instances; otherwise patch
        ``<Family>Attention`` instances. Only instances whose class name matches exactly
        are patched, so the model must already use that attention class.
    scale_base:
        base for the scale, equal to the pretraining length, e.g. 4096 for Llama and
        8192 for Gemma. Defaults to -1; how non-positive values are treated is decided
        by the SE forward implementations (presumably "no scaling" — TODO confirm).
        Two recommended scale factors:
            yarn: https://arxiv.org/abs/2309.00071
            log: https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823
        Scaling helps when retrieving from a long sequence (e.g. a long passkey), but on
        real-world data (e.g. LongBench, LEval) the impact is minor. The results reported
        in the paper do not use this scale except for long passkey retrieval.
    flash_attention_impl:
        only used for Llama with ``enable_flash_attention=True``:
          - "flash_attn" patches both ``_flash_attention_forward`` and ``forward``;
          - "triton" (default) patches only ``forward``, using the Triton variant.
        The other families always use the "flash_attn"-style patch and ignore this argument.

    Raises:
        NotImplementedError: the model's class name matches no supported family.
        ValueError (an Exception subclass): ``flash_attention_impl`` is neither
            'flash_attn' nor 'triton' for a Llama flash-attention model. Nothing is
            patched in that case.
        RuntimeError (an Exception subclass): no instance of the expected attention class
            was found for a method being patched. For the two-step "flash_attn" patch, the
            first step may already have been applied when this is raised.
    """
    arch_name = loaded_model.__class__.__name__

    # Checked in order; the first family name contained in the class name wins.
    supported_families = ("Llama", "Mistral", "Gemma", "Qwen2", "Phi")
    family = next((name for name in supported_families if name in arch_name), None)
    if family is None:
        raise NotImplementedError(
            f"SelfExtend is not implemented for {arch_name}; "
            f"supported model families: {', '.join(supported_families)}")
    # SE.<Family> (e.g. SE.Llama) holds that family's SelfExtend forward implementations.
    family_impl = getattr(SE, family)
    failure_message = f"Failed to modify the attention method of {arch_name}"

    def _bind_hyperparams(forward_fn):
        # Fix the SelfExtend hyper-parameters now; `self` is supplied later, when
        # modify_method_of_instance binds the partial to each attention instance.
        return partial(forward_fn,
                       group_size_1=group_size,
                       group_size_2=window_size,
                       scale_base=scale_base)

    def _patch(class_name, method_name, new_method):
        return modify_method_of_instance(loaded_model, class_name, method_name, new_method)

    if not enable_flash_attention:
        # Only the exact "<Family>Attention" class is patched. Per the original note, in
        # transformers 4.36 the default Llama attention class is LlamaSdpaAttention (before
        # 4.36 and in 4.38 it is LlamaAttention); a model using a differently named class is
        # not patched and raises below.
        forward = _bind_hyperparams(family_impl.self_extend_forward)
        if not _patch(f"{family}Attention", "forward", forward):
            raise RuntimeError(failure_message)
        return

    flash_class = f"{family}FlashAttention2"

    if family == "Llama":
        if flash_attention_impl == "triton":
            forward = _bind_hyperparams(family_impl.flash_self_extend_forward_triton)
            if not _patch(flash_class, "forward", forward):
                raise RuntimeError(failure_message)
            print("Using triton flash self_extend!!")
            return
        if flash_attention_impl != "flash_attn":
            # Validated before any patching, so the model is left untouched.
            raise ValueError("Need to set the flash_attention_impl to 'flash_attn' or 'triton'.")

    # "flash_attn"-style patch: replace _flash_attention_forward with SE's window-size
    # variant, then replace forward. Both calls always run, in this order.
    forward = _bind_hyperparams(family_impl.flash_self_extend_forward)
    patched_kernel = _patch(flash_class, "_flash_attention_forward",
                            SE.selfextend_flash_attn.flash_attention2_forward_with_window_size)
    patched_forward = _patch(flash_class, "forward", forward)
    if not (patched_kernel and patched_forward):
        raise RuntimeError(failure_message)
    if family == "Llama":
        print("Using flash_attn flash self_extend!!")
|