# DS6_UNetMSS3D_woDeform / UNetConfigs.py
# Uploaded by soumickmj in verified commit d9f6653 ("Upload UNetMSS3D"); 733 bytes.
from transformers import PretrainedConfig
from typing import List
class UNet3DConfig(PretrainedConfig):
    """Hugging Face configuration object for the 3D ``UNet`` model.

    Stores three model-specific hyper-parameters as attributes and hands
    every remaining keyword argument to ``PretrainedConfig``.

    Args:
        in_ch: Stored as ``self.in_ch`` (default 1). Presumably the number
            of input channels; confirm against the model class that reads it.
        out_ch: Stored as ``self.out_ch`` (default 1). Presumably the number
            of output channels; confirm against the model class.
        init_features: Stored as ``self.init_features`` (default 64).
            Presumably the feature width of the first network stage; confirm
            against the model class.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers records as the ``model_type`` field when the
    # config is serialised (e.g. to config.json); changing it would break
    # previously saved configs.
    model_type = "UNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Model-specific fields go on first, then the base class consumes the
        # generic options. This mirrors the original ordering.
        self.in_ch, self.out_ch, self.init_features = (
            in_ch,
            out_ch,
            init_features,
        )
        super().__init__(**kwargs)
class UNetMSS3DConfig(PretrainedConfig):
    """Hugging Face configuration object for the 3D ``UNetMSS`` model.

    Holds the same three hyper-parameters as the plain UNet config but
    registers under a distinct ``model_type``. What "MSS" stands for is not
    visible in this file (TODO confirm against the model implementation).

    Args:
        in_ch: Stored as ``self.in_ch`` (default 1). Presumably the number
            of input channels; confirm against the model class that reads it.
        out_ch: Stored as ``self.out_ch`` (default 1). Presumably the number
            of output channels; confirm against the model class.
        init_features: Stored as ``self.init_features`` (default 64).
            Presumably the feature width of the first network stage; confirm
            against the model class.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers records as the ``model_type`` field when the
    # config is serialised (e.g. to config.json); changing it would break
    # previously saved configs.
    model_type = "UNetMSS"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Model-specific fields go on first, then the base class consumes the
        # generic options. This mirrors the original ordering.
        self.in_ch, self.out_ch, self.init_features = (
            in_ch,
            out_ch,
            init_features,
        )
        super().__init__(**kwargs)