import torch | |
import torch.nn as nn | |
class LinearImplicitBackward(nn.Module):
    """Affine layer computing ``input @ weight.T + bias``.

    This class defines no ``__init__``, so ``weight`` and ``bias`` are not
    created here: they must be assigned on the instance (e.g. as
    ``nn.Parameter`` objects) before ``forward`` is called. ``bias`` may be
    ``None``, in which case no bias is added.

    Unlike the other classes in this module, it defines no ``has_backward``
    attribute. NOTE(review): presumably a consumer treats the missing flag as
    "implicit" — confirm against the caller.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``nn.functional.linear(input, self.weight, self.bias)``.

        Args:
            input: Tensor whose last dimension matches ``weight.shape[1]``.

        Returns:
            The affine transform of ``input``.
        """
        # The bare name ``F`` is not imported in this file; reach the
        # functional API through the already-imported ``nn`` namespace.
        return nn.functional.linear(input, self.weight, self.bias)
class LinearBackward(nn.Module):
    """Affine layer computing ``input @ weight.T + bias``, flagged ``has_backward = True``.

    This class defines no ``__init__``, so ``weight`` and ``bias`` are not
    created here: they must be assigned on the instance (e.g. as
    ``nn.Parameter`` objects) before ``forward`` is called. ``bias`` may be
    ``None``, in which case no bias is added.
    """

    # Class-level flag. NOTE(review): its meaning is defined by whatever code
    # reads it, which is not visible in this file — confirm against the caller.
    has_backward: bool = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``nn.functional.linear(input, self.weight, self.bias)``.

        Args:
            input: Tensor whose last dimension matches ``weight.shape[1]``.

        Returns:
            The affine transform of ``input``.
        """
        # The bare name ``F`` is not imported in this file; reach the
        # functional API through the already-imported ``nn`` namespace.
        return nn.functional.linear(input, self.weight, self.bias)
class LinearNoBackward(nn.Module):
    """Affine layer computing ``input @ weight.T + bias``, flagged ``has_backward = False``.

    This class defines no ``__init__``, so ``weight`` and ``bias`` are not
    created here: they must be assigned on the instance (e.g. as
    ``nn.Parameter`` objects) before ``forward`` is called. ``bias`` may be
    ``None``, in which case no bias is added.
    """

    # Class-level flag. NOTE(review): its meaning is defined by whatever code
    # reads it, which is not visible in this file — confirm against the caller.
    has_backward: bool = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``nn.functional.linear(input, self.weight, self.bias)``.

        Args:
            input: Tensor whose last dimension matches ``weight.shape[1]``.

        Returns:
            The affine transform of ``input``.
        """
        # The bare name ``F`` is not imported in this file; reach the
        # functional API through the already-imported ``nn`` namespace.
        return nn.functional.linear(input, self.weight, self.bias)
# Public API of this module: the three linear-layer variants, in definition order.
__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]