# Captured from a Hugging Face Space page (Space status at capture: "Runtime error").
# File size: 755 bytes; commit 1b8bd99.
import torch
from torch import nn
# LeNet-style convolutional network: two conv/tanh/avg-pool stages followed by a linear classifier
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/tanh/avg-pool stages, then a linear classifier.

    With the defaults (``in_channels=1``, ``num_classes=10``, ``input_size=28``)
    this is the original architecture. A 1x28x28 image is reduced to a 12x4x4
    feature map (192 features) before the final linear layer. Parameter names
    in ``state_dict`` are unchanged (``convs.0``, ``convs.3``, ``linear.0``).

    Args:
        in_channels: Number of channels in the input image.
        num_classes: Number of output logits.
        input_size: Height/width of the square input image. The number of
            features fed to the linear layer is derived from it, so it must
            match the spatial size of the tensors passed to ``forward``.

    Raises:
        ValueError: If ``input_size`` is too small for the two 5x5 convolutions
            and two 2x2 poolings to leave at least a 1x1 feature map (i.e. < 16).
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=28):
        super().__init__()
        # Each stage: an unpadded 5x5 conv shrinks each side by 4, then a 2x2
        # average pool with stride 2 halves it (rounding down).
        after_stage1 = (input_size - 4) // 2
        after_stage2 = (after_stage1 - 4) // 2
        if after_stage2 < 1:
            raise ValueError(
                f"input_size={input_size} is too small; LeNet needs inputs of "
                "at least 16x16 pixels"
            )

        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=4, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(in_channels=4, out_channels=12, kernel_size=(5, 5)),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        )
        self.linear = nn.Sequential(
            # 12 feature maps of after_stage2 x after_stage2 each (4x4 for 28x28 input).
            nn.Linear(12 * after_stage2 * after_stage2, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be a batched image tensor of shape
        ``(batch, in_channels, input_size, input_size)``. Other spatial sizes
        produce a flattened width that does not match the linear layer.
        """
        x = self.convs(x)
        # Keep the batch dimension; flatten channels and spatial dims together.
        x = torch.flatten(x, 1)
        return self.linear(x)