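# Provenance (scraped from the Hugging Face Hub file view): uploaded by
# soumickmj with commit message "Upload UNetMSS3D", revision 4cd46da
# (verified); original file size 875 bytes.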
from transformers import PretrainedConfig
from typing import List
class UNet3DConfig(PretrainedConfig):
    """Configuration container for the 3D UNet model.

    Records the architecture hyperparameters listed below as instance
    attributes and forwards every remaining keyword argument to
    ``PretrainedConfig.__init__``.

    Args:
        in_ch: Presumably the number of input channels (default 1) —
            confirm against the model code.
        out_ch: Presumably the number of output channels (default 1) —
            confirm against the model code.
        init_features: Presumably the feature width of the first network
            stage (default 64) — confirm against the model code.
        dropout_rate: Dropout value stored for the model (default 0.5);
            where it is applied is not visible here.
        **kwargs: Passed unchanged to ``PretrainedConfig``.
    """

    # Model-type identifier exposed to the transformers config machinery.
    model_type = "UNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        dropout_rate: float = 0.5,
        **kwargs,
    ):
        # Record the UNet-specific hyperparameters first; whatever is left in
        # kwargs is handled by the base PretrainedConfig.
        self.in_ch, self.out_ch = in_ch, out_ch
        self.init_features = init_features
        self.dropout_rate = dropout_rate
        super().__init__(**kwargs)
class UNetMSS3DConfig(PretrainedConfig):
    """Configuration container for the UNetMSS3D model.

    "MSS" presumably denotes a multi-scale-supervision UNet variant — confirm
    against the model code. Records the architecture hyperparameters listed
    below as instance attributes and forwards every remaining keyword
    argument to ``PretrainedConfig.__init__``.

    Args:
        in_ch: Presumably the number of input channels (default 1) —
            confirm against the model code.
        out_ch: Presumably the number of output channels (default 1) —
            confirm against the model code.
        init_features: Presumably the feature width of the first network
            stage (default 64) — confirm against the model code.
        dropout_rate: Dropout value stored for the model (default 0.5);
            where it is applied is not visible here.
        **kwargs: Passed unchanged to ``PretrainedConfig``.
    """

    # Model-type identifier exposed to the transformers config machinery.
    model_type = "UNetMSS"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        dropout_rate: float = 0.5,
        **kwargs,
    ):
        # Record the UNetMSS-specific hyperparameters first; whatever is left
        # in kwargs is handled by the base PretrainedConfig.
        self.in_ch, self.out_ch = in_ch, out_ch
        self.init_features = init_features
        self.dropout_rate = dropout_rate
        super().__init__(**kwargs)