import torch
import torch.nn as nn
from . import SparseTensor

# Names exported by `from <this module> import *`: the sparse activation
# wrappers defined in this file.
__all__ = [
    'SparseReLU',
    'SparseSiLU',
    'SparseGELU',
    'SparseActivation'
]


class SparseReLU(nn.ReLU):
    """ReLU that operates on the ``feats`` of a ``SparseTensor``.

    Constructor arguments are those of ``torch.nn.ReLU``; only ``forward``
    is overridden.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply ReLU to ``input.feats`` and return ``input.replace(...)``.

        Args:
            input: Sparse tensor whose ``feats`` are passed through ReLU.

        Returns:
            Whatever ``input.replace`` yields for the activated features
            (presumably a SparseTensor carrying the new features -- see
            ``SparseTensor.replace`` for the exact semantics).
        """
        activated = super().forward(input.feats)
        return input.replace(activated)
    

class SparseSiLU(nn.SiLU):
    """SiLU that operates on the ``feats`` of a ``SparseTensor``.

    Constructor arguments are those of ``torch.nn.SiLU``; only ``forward``
    is overridden.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply SiLU to ``input.feats`` and return ``input.replace(...)``.

        Args:
            input: Sparse tensor whose ``feats`` are passed through SiLU.

        Returns:
            Whatever ``input.replace`` yields for the activated features
            (presumably a SparseTensor carrying the new features -- see
            ``SparseTensor.replace`` for the exact semantics).
        """
        activated = super().forward(input.feats)
        return input.replace(activated)


class SparseGELU(nn.GELU):
    """GELU that operates on the ``feats`` of a ``SparseTensor``.

    Constructor arguments are those of ``torch.nn.GELU``; only ``forward``
    is overridden.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply GELU to ``input.feats`` and return ``input.replace(...)``.

        Args:
            input: Sparse tensor whose ``feats`` are passed through GELU.

        Returns:
            Whatever ``input.replace`` yields for the activated features
            (presumably a SparseTensor carrying the new features -- see
            ``SparseTensor.replace`` for the exact semantics).
        """
        activated = super().forward(input.feats)
        return input.replace(activated)


class SparseActivation(nn.Module):
    """Wrap an arbitrary module so it is applied to a SparseTensor's ``feats``.

    The wrapped module is stored as ``self.activation`` and is therefore
    registered as a submodule of this one.
    """

    def __init__(self, activation: nn.Module):
        """Store the module to apply to the features.

        Args:
            activation: Module called on ``input.feats`` in ``forward``.
        """
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Run ``self.activation`` on ``input.feats`` and return ``input.replace(...)``.

        Args:
            input: Sparse tensor whose ``feats`` are fed to the wrapped module.

        Returns:
            Whatever ``input.replace`` yields for the transformed features
            (presumably a SparseTensor carrying the new features -- see
            ``SparseTensor.replace`` for the exact semantics).
        """
        features = input.feats
        transformed = self.activation(features)
        return input.replace(transformed)