sachit-menon committed on
Commit
d21eb1b
·
verified ·
1 Parent(s): b207f23

Update sd_model.py

Browse files
Files changed (1) hide show
  1. sd_model.py +7 -5
sd_model.py CHANGED
@@ -55,18 +55,18 @@ class LoraConfig:
55
 
56
  @dataclass
57
  class SDModelConfig(BaseModelConfig):
58
- _target_: str = "trainer.models.sd_model.SDModel"
59
  pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
60
  conditioning_dropout_prob: float = 0.05
61
  use_ema: bool = True
62
- concat_all_steps: bool = II("dataset.concat_all_steps")
63
  positional_encoding_type: Optional[str] = "sinusoidal"
64
  positional_encoding_length: Optional[int] = None
65
  image_positional_encoding_type: Optional[str] = None #"sinusoidal"
66
  image_positional_encoding_length: Optional[int] = None
67
  broadcast_positional_encoding: bool = True
68
- sequence_length: Optional[int] = II("dataset.sequence_length") # TODO consider changing interp on next line to this +1?
69
- text_sequence_length: Optional[int] = II("dataset.text_sequence_length")
70
  use_lora: bool = False
71
  # lora_cfg: Any = LoraConfig()
72
  zero_snr: bool = True
@@ -76,8 +76,10 @@ class SDModelConfig(BaseModelConfig):
76
 
77
 
78
  class SDModel(ModelMixin, ConfigMixin, PushToHubMixin):
79
- def __init__(self, cfg: SDModelConfig) -> None:
80
  super().__init__()
 
 
81
  self.cfg = cfg
82
  self.noise_scheduler = DDPMScheduler.from_pretrained(
83
  self.cfg.pretrained_model_name_or_path,
 
55
 
56
  @dataclass
57
  class SDModelConfig(BaseModelConfig):
58
+ _target_: str = "sd_model.SDModel"
59
  pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
60
  conditioning_dropout_prob: float = 0.05
61
  use_ema: bool = True
62
+ concat_all_steps: bool = False
63
  positional_encoding_type: Optional[str] = "sinusoidal"
64
  positional_encoding_length: Optional[int] = None
65
  image_positional_encoding_type: Optional[str] = None #"sinusoidal"
66
  image_positional_encoding_length: Optional[int] = None
67
  broadcast_positional_encoding: bool = True
68
+ sequence_length: Optional[int] = 6
69
+ text_sequence_length: Optional[int] = 7
70
  use_lora: bool = False
71
  # lora_cfg: Any = LoraConfig()
72
  zero_snr: bool = True
 
76
 
77
 
78
  class SDModel(ModelMixin, ConfigMixin, PushToHubMixin):
79
+ def __init__(self, cfg: SDModelConfig = None) -> None:
80
  super().__init__()
81
+ if cfg is not None: # workaround for default
82
+ cfg = SDM
83
  self.cfg = cfg
84
  self.noise_scheduler = DDPMScheduler.from_pretrained(
85
  self.cfg.pretrained_model_name_or_path,