danieldk HF Staff commited on
Commit
9642b02
·
1 Parent(s): 593ea98
build/torch-universal/backward_marker_test/layers.py CHANGED
@@ -4,11 +4,14 @@ from torch.nn import functional as F
4
 
5
 
6
  class LinearImplicitBackward(nn.Module):
 
 
7
  def forward(self, input: torch.Tensor) -> torch.Tensor:
8
  return F.linear(input, self.weight, self.bias)
9
 
10
 
11
  class LinearBackward(nn.Module):
 
12
  has_backward = True
13
 
14
  def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -16,6 +19,7 @@ class LinearBackward(nn.Module):
16
 
17
 
18
  class LinearNoBackward(nn.Module):
 
19
  has_backward = False
20
 
21
  def forward(self, input: torch.Tensor) -> torch.Tensor:
 
4
 
5
 
6
  class LinearImplicitBackward(nn.Module):
7
+ can_torch_compile = True
8
+
9
  def forward(self, input: torch.Tensor) -> torch.Tensor:
10
  return F.linear(input, self.weight, self.bias)
11
 
12
 
13
  class LinearBackward(nn.Module):
14
+ can_torch_compile = True
15
  has_backward = True
16
 
17
  def forward(self, input: torch.Tensor) -> torch.Tensor:
 
19
 
20
 
21
  class LinearNoBackward(nn.Module):
22
+ can_torch_compile = True
23
  has_backward = False
24
 
25
  def forward(self, input: torch.Tensor) -> torch.Tensor: