from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter


class SDXLVAEDecoder(SDVAEDecoder):
    """Subclass of ``SDVAEDecoder`` that sets ``scaling_factor`` to 0.13025.

    All other behaviour is inherited from the base class; this class only
    assigns the attribute after the base initializer has run and exposes its
    own state-dict converter.
    """

    def __init__(self, upcast_to_float32=True):
        """Initialize the base decoder, then set ``scaling_factor``.

        Args:
            upcast_to_float32: Accepted for interface compatibility.
                NOTE(review): this value is not read anywhere in this class;
                the converters in this module emit the same key as an extra
                keyword, presumably for whoever constructs the model --
                confirm against the loader.
        """
        super().__init__()
        # Assigned after the base initializer so it replaces any value the
        # base class may have set for the same attribute.
        self.scaling_factor = 0.13025

    @staticmethod
    def state_dict_converter():
        """Return a new ``SDXLVAEDecoderStateDictConverter`` instance."""
        return SDXLVAEDecoderStateDictConverter()


class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    """State-dict converter for ``SDXLVAEDecoder``.

    Each conversion delegates to the parent converter and additionally
    returns a dict of extra keyword arguments (``upcast_to_float32=True``).
    """

    def __init__(self):
        """Run the parent converter's initializer; no extra state is added."""
        super().__init__()

    def from_diffusers(self, state_dict):
        """Convert via the parent's ``from_diffusers`` and attach extras.

        Returns:
            A ``(converted_state_dict, extra_kwargs)`` tuple, where
            ``extra_kwargs`` is a freshly built ``{"upcast_to_float32": True}``.
        """
        converted = super().from_diffusers(state_dict)
        # A new dict per call, so callers may mutate it freely.
        extra_kwargs = {"upcast_to_float32": True}
        return converted, extra_kwargs

    def from_civitai(self, state_dict):
        """Convert via the parent's ``from_civitai`` and attach extras.

        Returns:
            A ``(converted_state_dict, extra_kwargs)`` tuple, where
            ``extra_kwargs`` is a freshly built ``{"upcast_to_float32": True}``.
        """
        converted = super().from_civitai(state_dict)
        # A new dict per call, so callers may mutate it freely.
        extra_kwargs = {"upcast_to_float32": True}
        return converted, extra_kwargs