crystal-technologies's picture
Upload 2711 files
6e73cd3
raw
history blame contribute delete
262 Bytes
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
from torch import nn
# Registry mapping a string key to a fully-connected (linear) layer class.
# 'torch' is always available; 'te' is added only when the optional
# transformer_engine package can be imported.
FC_CLASS_REGISTRY = {
    'torch': nn.Linear,
}

try:
    import transformer_engine.pytorch as te
except Exception:
    # transformer_engine is optional: any failure while importing it (not
    # installed, missing native/CUDA libraries, ...) simply means the 'te'
    # entry is unavailable. Catching Exception rather than using a bare
    # `except:` keeps KeyboardInterrupt / SystemExit from being swallowed.
    pass
else:
    FC_CLASS_REGISTRY['te'] = te.Linear