import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from craftsman.utils.typing import *
from craftsman.utils.checkpoint import checkpoint

from .utils import init_linear
from .attention import ResidualAttentionBlock


class Perceiver(nn.Module):
    """A stack of residual self-attention blocks applied to fixed-length token sequences."""

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        use_flash: bool = False,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        # Stack of residual attention blocks, each operating on sequences of
        # length `n_ctx` with embedding dimension `width`.
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                    qkv_bias=qkv_bias,
                    use_flash=use_flash,
                    use_checkpoint=use_checkpoint,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        # Apply the attention blocks sequentially; the sequence shape is preserved.
        for block in self.resblocks:
            x = block(x)
        return x