File size: 971 Bytes
67c46fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch


class ForwardAdaptor(torch.nn.Module):
    """Expose an arbitrary method of a wrapped module as ``forward()``.

    torch.nn.DataParallel parallelizes only "forward()", so a method with
    any other name cannot be run through it directly. Wrapping the module
    in this class redirects ``forward()`` to the chosen method.

    Args:
        module: The module that owns the target method.
        name: Attribute name on ``module`` that ``forward()`` should call.

    Raises:
        ValueError: If ``module`` has no attribute called ``name``.

    Examples:
        >>> class A(torch.nn.Module):
        ...     def foo(self, x):
        ...         ...
        >>> model = A()
        >>> model = ForwardAdaptor(model, "foo")
        >>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
        >>> x = torch.randn(2, 10)
        >>> model(x)
    """

    def __init__(self, module: torch.nn.Module, name: str):
        super().__init__()
        # Reject an unknown attribute up front, before any state is stored.
        if not hasattr(module, name):
            raise ValueError(f"{module} doesn't have {name}")
        self.name = name
        self.module = module

    def forward(self, *args, **kwargs):
        """Call the wrapped module's target method with the given arguments."""
        # Looked up on every call, so the attribute is resolved against
        # whichever module object this wrapper currently holds.
        return getattr(self.module, self.name)(*args, **kwargs)