File size: 2,828 Bytes
e0c2d04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#  Copyright 2022, Lefebvre Dalloz Services
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
import importlib
import inspect
from typing import Any, Dict


def code_patcher(module_name: str, function: Any, new_function_name: str, modifications: Dict[str, str]) -> None:
    """
    Define a patched copy of ``function`` inside the module ``module_name``.

    The source code of ``function`` is retrieved with :func:`inspect.getsource`. Each key of ``modifications`` is
    then replaced by its value, in dictionary order, each replacement being applied to the output of the previous
    one. The ``def`` header is renamed to ``new_function_name``, and the result is executed in the namespace of
    ``module_name``. The original function is left untouched. Once ``code_patcher()`` has run, you still need to
    assign the new function (available as ``<module>.<new_function_name>``) wherever the original was used.

    :param module_name: fully qualified name of the module in which the patched function is defined. The patched
        source is executed with this module's ``__dict__`` as globals, so it resolves global names in that module.
    :param function: the function (or method) whose source code is patched
    :param new_function_name: name under which the patched function is defined in ``module_name``
    :param modifications: mapping from source snippets to replacement snippets. Every snippet must appear in the
        (progressively modified) source code, and every occurrence of it is replaced.
    :return: None
    :raises AssertionError: if a snippet of ``modifications`` is not found in the source code, or if the ``def``
        header of ``function`` could not be renamed (e.g. ``function`` is a lambda). ``AssertionError`` is kept for
        backward compatibility, but it is raised explicitly so that the checks still run under ``python -O``.

    Note: the patched function is compiled at module level, outside any class body. Constructs that depend on the
    enclosing class (zero-argument ``super()``, name-mangled ``__private`` attributes) therefore do not behave as
    they do in the original method.

    Example:
        To patch ``transformers.models.t5.modeling_t5.T5Attention.forward`` under the new name ``updatedForward``::

            >>> import transformers
            >>> code_patcher(module_name="transformers.models.t5.modeling_t5",
            ...              function=transformers.models.t5.modeling_t5.T5Attention.forward,
            ...              new_function_name="updatedForward",
            ...              modifications={"return outputs": "return True"})
            >>> transformers.models.t5.modeling_t5.T5Attention.forward = transformers.models.t5.modeling_t5.updatedForward
    """
    module = importlib.import_module(name=module_name)
    function_name = function.__name__
    function_code = inspect.getsource(function)
    for src_code, new_code in modifications.items():
        if src_code not in function_code:
            # Raised explicitly instead of via `assert`, so the check is not stripped by `python -O`.
            raise AssertionError(
                f"Failed to update function {function_name} in module {module_name}: "
                f'\n"{src_code}" was not found in {function_name} source code'
            )
        function_code = function_code.replace(src_code, new_code)

    old_header = f"def {function_name}("
    new_header = f"def {new_function_name}("
    # Only rename the function's own header (its first `def`). A nested function sharing the same name must keep
    # its name, otherwise the body's calls would silently resolve to a different function.
    function_code = function_code.replace(old_header, new_header, 1)
    if new_header not in function_code:
        # Without a renamed header, exec() would (re)bind the original name in the module and never define
        # `new_function_name`.
        raise AssertionError(
            f"Failed to update function {function_name} in module {module_name}: "
            f'\n"{old_header}" was not found in {function_name} source code'
        )
    # cleandoc strips the first line's indentation separately from the others. The leading newline makes every
    # real source line (e.g. an indented method) subject to the same common-indentation removal.
    exec(inspect.cleandoc("\n" + function_code), module.__dict__, module.__dict__)