# Source snapshot: revision 96bbf6c (835 bytes).
from transformers import PretrainedConfig
class ProbUNetConfig(PretrainedConfig):
    """Configuration for a probabilistic U-Net model (``model_type="ProbUNet"``).

    Stores the hyperparameters below as plain attributes on the config.
    Every other keyword argument is forwarded to
    ``transformers.PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their names;
    confirm them against the model implementation that consumes this config.

    Args:
        dim: Presumably the number of spatial dimensions (e.g. 2 for 2-D
            inputs). Default ``2``.
        in_channels: Presumably the number of input channels. Default ``1``.
        out_channels: Presumably the number of output channels. Default ``1``.
        num_feature_maps: Presumably the base number of feature maps of the
            U-Net. Default ``24``.
        latent_size: Presumably the dimensionality of the latent space.
            Default ``3``.
        depth: Presumably the number of U-Net resolution levels.
            Default ``5``.
        latent_distribution: Name of the latent distribution. Only
            ``"normal"`` is visible here; other accepted values are defined
            by the model code. Default ``"normal"``.
        no_outact_op: Boolean flag, presumably disabling an output
            activation/operation. Default ``False``.
        prob_injection_at: Where the probabilistic (latent) sample is
            injected. Only ``"end"`` is visible here. Default ``"end"``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier string transformers associates with this config class.
    model_type = "ProbUNet"

    def __init__(
        self,
        dim: int = 2,
        in_channels: int = 1,
        out_channels: int = 1,
        num_feature_maps: int = 24,
        latent_size: int = 3,
        depth: int = 5,
        latent_distribution: str = "normal",
        no_outact_op: bool = False,
        prob_injection_at: str = "end",
        **kwargs,
    ):
        # Model-specific hyperparameters, stored verbatim as attributes.
        self.dim = dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_feature_maps = num_feature_maps
        self.latent_size = latent_size
        self.depth = depth
        self.latent_distribution = latent_distribution
        self.no_outact_op = no_outact_op
        self.prob_injection_at = prob_injection_at
        # Hand all remaining keyword arguments to the base class so generic
        # PretrainedConfig options reach it.
        super().__init__(**kwargs)