# Uploaded by soumickmj ("Upload ProbUNet"), commit 7b52c3c (verified), 835 bytes.
from transformers import PretrainedConfig
class ProbUNetConfig(PretrainedConfig):
    """Hugging Face configuration object for a ProbUNet model.

    Each named constructor argument is stored unchanged as an attribute of
    the same name. No validation happens here; which values are accepted is
    decided by whatever model code consumes this config. Any extra keyword
    arguments are passed through to ``PretrainedConfig.__init__``.

    Args:
        dim: Presumably the spatial dimensionality of the network
            (e.g. 2 or 3) -- TODO confirm against the model code. Default 2.
        in_channels: Input channel count (by name). Default 1.
        out_channels: Output channel count (by name). Default 1.
        num_feature_maps: Feature-map count; presumably the width of the
            first level -- TODO confirm. Default 24.
        latent_size: Size of the latent space (by name). Default 3.
        depth: Presumably the number of U-Net levels -- TODO confirm.
            Default 5.
        latent_distribution: Name of the latent distribution; accepted
            values are defined by the model code. Default ``"normal"``.
        no_outact_op: Boolean flag stored as-is; by its name it presumably
            disables the output activation -- verify. Default ``False``.
        prob_injection_at: Injection point for the probabilistic/latent
            branch; accepted values are defined by the model code.
            Default ``"end"``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers uses for this config type (serialized as
    # "model_type" in config.json).
    model_type = "ProbUNet"

    def __init__(
        self,
        dim: int = 2,
        in_channels: int = 1,
        out_channels: int = 1,
        num_feature_maps: int = 24,
        latent_size: int = 3,
        depth: int = 5,
        latent_distribution: str = "normal",
        no_outact_op: bool = False,
        prob_injection_at: str = "end",
        **kwargs
    ):
        # Model-specific hyper-parameters, listed in the order they are
        # assigned. setattr() goes through the same __setattr__ hook that a
        # plain `self.<name> = value` statement would.
        settings = {
            "dim": dim,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "num_feature_maps": num_feature_maps,
            "latent_size": latent_size,
            "depth": depth,
            "latent_distribution": latent_distribution,
            "no_outact_op": no_outact_op,
            "prob_injection_at": prob_injection_at,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

        # Everything not consumed above belongs to the base config; it is
        # initialised only after the model-specific attributes are in place.
        super().__init__(**kwargs)