File size: 262 Bytes
6e73cd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
from torch import nn
# Registry mapping a backend name to the fully-connected (linear) layer class
# used for it. 'torch' (plain ``torch.nn.Linear``) is always available; 'te'
# is added below only when Transformer Engine can be imported.
FC_CLASS_REGISTRY = {
    'torch': nn.Linear,
}

# Transformer Engine is an optional dependency, so registering it is
# best-effort. ``except Exception`` (instead of a bare ``except:``) still skips
# any ordinary failure raised while importing the package — not necessarily
# only ImportError, e.g. if native libraries fail to load (NOTE(review): the
# exact failure modes are not visible here). However, it no longer swallows
# KeyboardInterrupt / SystemExit raised during the import.
try:
    import transformer_engine.pytorch as te
    FC_CLASS_REGISTRY['te'] = te.Linear
except Exception:
    pass
|