File size: 479 Bytes
3eb682b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
""" Linear layer (alternate definition)
"""
import torch
import torch.nn.functional as F
from torch import nn as nn
class Linear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting():
bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
else:
return F.linear(input, self.weight, self.bias)
|