|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib |
|
import inspect |
|
from typing import Any, Dict |
|
|
|
|
|
def code_patcher(module_name: str, function: Any, new_function_name: str, modifications: Dict[str, str]) -> Any:
    """
    Define a patched copy of ``function`` in module ``module_name`` under a new name.

    The source of ``function`` is read with :func:`inspect.getsource`. Each key of ``modifications`` is then
    replaced by its value. Replacements run in dict iteration order, so a later key is searched for in the
    already-patched source. The outer ``def`` line is renamed to ``new_function_name``, and the result is executed
    with the module's ``__dict__`` as globals. Names used by the patched body therefore resolve exactly as they do
    in that module. The original function is left untouched: after ``code_patcher()`` returns, assign the new
    function (or this function's return value) wherever the old one was used.

    :param module_name: dotted name of the module in which the new function is defined; it is imported with
        :func:`importlib.import_module`
    :param function: the function whose source is patched; its source must be retrievable by
        :func:`inspect.getsource`
    :param new_function_name: name under which the patched function is defined in ``module_name``
    :param modifications: mapping of original source snippet -> replacement snippet; every snippet must be
        present in the (progressively patched) source
    :return: the newly defined function object, i.e. ``getattr(module, new_function_name)``
    :raises AssertionError: if a key of ``modifications``, or the ``def <function.__name__>(`` header, is not
        found in the source (raised explicitly, so the check also runs under ``python -O``)

    Example:
        Replacing ``return outputs`` in ``transformers.models.t5.modeling_t5.T5Attention.forward`` and using
        ``updatedForward`` as the new function name:

        >>> import transformers
        >>> _ = code_patcher(
        ...     module_name="transformers.models.t5.modeling_t5",
        ...     function=transformers.models.t5.modeling_t5.T5Attention.forward,
        ...     new_function_name="updatedForward",
        ...     modifications={"return outputs": "return True"},
        ... )
        >>> transformers.models.t5.modeling_t5.T5Attention.forward = transformers.models.t5.modeling_t5.updatedForward
    """
    import textwrap  # local import keeps this edit self-contained

    model_module = importlib.import_module(name=module_name)
    function_name = function.__name__
    function_code = inspect.getsource(function)

    for src_code, new_code in modifications.items():
        # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``. AssertionError is kept
        # so callers that already catch it keep working.
        if src_code not in function_code:
            raise AssertionError(
                f"Failed to update function {function_name} in module {module_name}: "
                f'\n"{src_code}" was not found in {function_name} source code'
            )
        function_code = function_code.replace(src_code, new_code)

    # Without this check, a missing header (e.g. ``__name__`` altered by a decorator) would make the exec below
    # silently define the function under its old name instead of ``new_function_name``.
    def_header = f"def {function_name}("
    if def_header not in function_code:
        raise AssertionError(
            f"Failed to update function {function_name} in module {module_name}: "
            f'\n"{def_header}" was not found in {function_name} source code'
        )
    # Only rename the first (outer) definition; decorators precede it, nested defs come after it.
    function_code = function_code.replace(def_header, f"def {new_function_name}(", 1)

    # getsource keeps the indentation of methods / nested functions; dedent so it compiles at module level.
    # Unlike inspect.cleandoc, textwrap.dedent does not expand tabs inside string literals.
    patched_source = textwrap.dedent(function_code)
    code_object = compile(patched_source, f"<code_patcher: {module_name}.{new_function_name}>", "exec")
    # Executing with the module's dict as globals both resolves free names there and stores the new function.
    exec(code_object, model_module.__dict__)
    return model_module.__dict__[new_function_name]
|
|