# potterGPT-v0 / model / feed_forward.py
# Source: Hugging Face file view (uploader: nullHawk; commit "add: v0", 9fe7c42, verified; 418 bytes).
import torch
import torch.nn as nn
class FeedForward(nn.Module):
    """Position-wise feed-forward block: expand, ReLU, project back, dropout.

    The hidden layer is four times as wide as ``Config.n_embed``; the output
    has the same feature size as the input.

    Args:
        Config: Any object exposing ``n_embed`` (feature size of the input
            and output) and ``block_dropout`` (dropout probability applied
            after the projection). The capitalised parameter name is kept
            so existing keyword callers continue to work.
    """

    def __init__(self, Config):
        super().__init__()
        hidden_dim = Config.n_embed * 4
        # Keep the layer order fixed: nn.Sequential names parameters by
        # position (``net.0.*``, ``net.2.*``), so reordering would rename
        # the entries in ``state_dict()``.
        self.net = nn.Sequential(
            nn.Linear(Config.n_embed, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, Config.n_embed),  # projection
            nn.Dropout(Config.block_dropout),
        )

    def forward(self, x):
        """Run ``x`` through the network; its last dimension must be ``n_embed``."""
        return self.net(x)