|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
class UNet3DConfig(PretrainedConfig):
    """Configuration object registered under the ``"UNet"`` model type.

    Stores four hyperparameters as instance attributes and forwards every
    remaining keyword argument to ``PretrainedConfig.__init__``.

    Args:
        in_ch: Saved as ``self.in_ch`` (presumably the number of input
            channels -- confirm against the model code).
        out_ch: Saved as ``self.out_ch`` (presumably the number of output
            channels -- confirm against the model code).
        init_features: Saved as ``self.init_features`` (presumably the
            feature width of the first layer -- confirm against the model).
        dropout_rate: Saved as ``self.dropout_rate`` (presumably a dropout
            probability -- confirm against the model code).
        **kwargs: Passed through unchanged to the base-class constructor.
    """

    model_type = "UNet"

    def __init__(self, in_ch=1, out_ch=1, init_features=64, dropout_rate=0.5, **kwargs):
        # The model-specific fields are assigned first, in declaration order,
        # and only afterwards are the leftover kwargs handed to the base class.
        hyperparameters = (
            ("in_ch", in_ch),
            ("out_ch", out_ch),
            ("init_features", init_features),
            ("dropout_rate", dropout_rate),
        )
        for attr_name, attr_value in hyperparameters:
            setattr(self, attr_name, attr_value)

        super().__init__(**kwargs)
|
|
|
class UNetMSS3DConfig(PretrainedConfig):
    """Configuration object registered under the ``"UNetMSS"`` model type.

    Stores four hyperparameters as instance attributes and forwards every
    remaining keyword argument to ``PretrainedConfig.__init__``.

    Args:
        in_ch: Saved as ``self.in_ch`` (presumably the number of input
            channels -- confirm against the model code).
        out_ch: Saved as ``self.out_ch`` (presumably the number of output
            channels -- confirm against the model code).
        init_features: Saved as ``self.init_features`` (presumably the
            feature width of the first layer -- confirm against the model).
        dropout_rate: Saved as ``self.dropout_rate`` (presumably a dropout
            probability -- confirm against the model code).
        **kwargs: Passed through unchanged to the base-class constructor.
    """

    model_type = "UNetMSS"

    def __init__(self, in_ch=1, out_ch=1, init_features=64, dropout_rate=0.5, **kwargs):
        # The model-specific fields are assigned first, in declaration order,
        # and only afterwards are the leftover kwargs handed to the base class.
        hyperparameters = (
            ("in_ch", in_ch),
            ("out_ch", out_ch),
            ("init_features", init_features),
            ("dropout_rate", dropout_rate),
        )
        for attr_name, attr_value in hyperparameters:
            setattr(self, attr_name, attr_value)

        super().__init__(**kwargs)