import torch
import torch.nn as nn


class Net(nn.Module):
    """Three-layer fully connected network with ReLU after each hidden layer.

    All input dimensions after the first are flattened into a single feature
    axis, so the product of those trailing dimensions must equal
    ``input_size``. No activation is applied to the output of the last layer.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size1: int,
        hidden_size2: int,
        output_size: int,
    ) -> None:
        """Build the layers.

        Args:
            input_size: Number of features after flattening dims 1..N of the input.
            hidden_size1: Width of the first hidden layer.
            hidden_size2: Width of the second hidden layer.
            output_size: Number of output features.
        """
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_size1)
        self.layer2 = nn.Linear(hidden_size1, hidden_size2)
        self.layer3 = nn.Linear(hidden_size2, output_size)
        # One ReLU module is reused after both hidden layers; it has no parameters.
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Tensor with at least 2 dims; dim 0 is preserved (presumably the
                batch axis — TODO confirm against callers) and the remaining
                dims are flattened into ``input_size`` features.

        Returns:
            Tensor of shape ``(x.shape[0], output_size)``.
        """
        # Merge every dim after the first into one feature axis for nn.Linear.
        x = x.flatten(start_dim=1)
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        return self.layer3(x)