|
from transformers import PretrainedConfig |
|
|
|
class ProbUNetConfig(PretrainedConfig):
    """Hugging Face configuration object for the ProbUNet model.

    Each constructor argument is stored, unmodified and unvalidated, as an
    instance attribute of the same name. Any extra keyword arguments are
    forwarded to ``PretrainedConfig.__init__`` after those attributes are set.

    The per-argument notes below are inferred from the argument names; the
    model implementation that consumes this config is authoritative.

    Args:
        dim: Dimensionality setting, default ``2`` (presumably 2-D vs 3-D
            inputs -- NOTE(review): confirm against the model code).
        in_channels: Input channel count, default ``1``.
        out_channels: Output channel count, default ``1``.
        num_feature_maps: Feature-map count, default ``24`` (presumably the
            width of the first level -- confirm).
        latent_size: Latent size, default ``3``.
        depth: Network depth, default ``5`` (presumably the number of
            encoder/decoder levels -- confirm).
        latent_distribution: Name of the latent distribution, default
            ``"normal"``.
        no_outact_op: Boolean flag, default ``False`` (presumably disables an
            output activation op -- confirm).
        prob_injection_at: Where the probabilistic component is injected,
            default ``"end"``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ProbUNet"

    def __init__(
        self,
        dim: int = 2,
        in_channels: int = 1,
        out_channels: int = 1,
        num_feature_maps: int = 24,
        latent_size: int = 3,
        depth: int = 5,
        latent_distribution: str = "normal",
        no_outact_op: bool = False,
        prob_injection_at: str = "end",
        **kwargs,
    ):
        # Spatial / channel layout.
        self.dim = dim
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Network capacity.
        self.num_feature_maps = num_feature_maps
        self.latent_size = latent_size
        self.depth = depth

        # Probabilistic-component options.
        self.latent_distribution = latent_distribution
        self.no_outact_op = no_outact_op
        self.prob_injection_at = prob_injection_at

        # The base initializer runs last, after the model-specific attributes
        # above are already in place; it receives only the leftover kwargs.
        super().__init__(**kwargs)