# custom-transmxm / configuration_transmxm.py
# Uploaded by Huhujingjing ("Upload 2 files", commit 3717306)
from typing import List, Optional

from transformers import PretrainedConfig
class TransmxmConfig(PretrainedConfig):
    """Configuration for the TransMXM model.

    Holds the model hyperparameters, optional SMILES strings, and the name of
    the processor class. Everything not listed here is forwarded to
    ``PretrainedConfig``, so the config can be saved with ``save_pretrained``.

    Args:
        dim: Dimension of the input feature.
        n_layer: Number of GCN layers.
        cutoff: Cutoff distance used for neighbor searching.
        num_spherical: Number of spherical harmonics.
        num_radial: Number of radial basis functions.
        envelope_exponent: Envelope exponent.
        smiles: Optional SMILES strings to store on the config. A shallow
            copy is kept, so later mutation of the caller's list does not
            silently change this config. ``None`` means none were supplied.
        processor_class: Name of the processor class recorded in the config.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "transmxm"

    def __init__(
        self,
        dim: int = 128,
        n_layer: int = 6,
        cutoff: float = 5.0,
        num_spherical: int = 7,
        num_radial: int = 6,
        envelope_exponent: int = 5,
        smiles: Optional[List[str]] = None,
        processor_class: str = "SmilesProcessor",
        **kwargs,
    ):
        self.dim = dim  # dimension of the input feature
        self.n_layer = n_layer  # number of GCN layers
        self.cutoff = cutoff  # cutoff distance for neighbor searching
        self.num_spherical = num_spherical  # number of spherical harmonics
        self.num_radial = num_radial  # number of radial basis functions
        self.envelope_exponent = envelope_exponent  # envelope exponent
        # Take a copy so the config owns its data instead of aliasing the
        # caller's list; keep None as-is to mean "no SMILES supplied".
        self.smiles = list(smiles) if smiles is not None else None
        self.processor_class = processor_class
        # Model-specific attributes are set first; the remaining options
        # are handled by the base class.
        super().__init__(**kwargs)
if __name__ == "__main__":
    # Example run: build a config with explicit hyperparameters and a few
    # sample SMILES strings, then write it out with save_pretrained.
    example_kwargs = dict(
        dim=128,
        n_layer=6,
        cutoff=5.0,
        num_spherical=7,
        num_radial=6,
        envelope_exponent=5,
        smiles=["C", "CC", "CCC"],
        processor_class="SmilesProcessor",
    )
    config = TransmxmConfig(**example_kwargs)
    config.save_pretrained("custom-transmxm")