Build
Browse files
build/torch-universal/backward_marker_test/layers.py
CHANGED
@@ -4,11 +4,14 @@ from torch.nn import functional as F
|
|
4 |
|
5 |
|
6 |
class LinearImplicitBackward(nn.Module):
|
|
|
|
|
7 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
8 |
return F.linear(input, self.weight, self.bias)
|
9 |
|
10 |
|
11 |
class LinearBackward(nn.Module):
|
|
|
12 |
has_backward = True
|
13 |
|
14 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
@@ -16,6 +19,7 @@ class LinearBackward(nn.Module):
|
|
16 |
|
17 |
|
18 |
class LinearNoBackward(nn.Module):
|
|
|
19 |
has_backward = False
|
20 |
|
21 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
|
4 |
|
5 |
|
6 |
class LinearImplicitBackward(nn.Module):
|
7 |
+
can_torch_compile = True
|
8 |
+
|
9 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
10 |
return F.linear(input, self.weight, self.bias)
|
11 |
|
12 |
|
13 |
class LinearBackward(nn.Module):
|
14 |
+
can_torch_compile = True
|
15 |
has_backward = True
|
16 |
|
17 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
|
19 |
|
20 |
|
21 |
class LinearNoBackward(nn.Module):
|
22 |
+
can_torch_compile = True
|
23 |
has_backward = False
|
24 |
|
25 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|