import torch
from comfy import model_management


def string_to_dtype(s="none", mode=None):
    """Translate a user-facing dtype name into a torch dtype or marker value.

    Args:
        s: Dtype name. Matching is case-insensitive and ignores surrounding
            whitespace. ``None`` is accepted and returns ``None``.
        mode: Component the dtype is for. Only consulted for ``"auto"`` /
            ``"auto (comfy)"``, where it must be ``"vae"``,
            ``"text_encoder"`` or ``"unet"``.

    Returns:
        * ``None`` for ``None``, "default", "as-is", "none", "auto (hf)" and
          "auto (hf/bnb)".
        * The result of ``model_management.vae_dtype()``,
          ``text_encoder_dtype()`` or ``unet_dtype()`` for the comfy auto
          choices, selected by ``mode``.
        * ``torch.float32`` / ``torch.bfloat16`` / ``torch.float16`` /
          ``torch.float8_e5m2`` / ``torch.float8_e4m3fn`` for explicit names.
        * The normalized string ``"bnb8bit"`` or ``"bnb4bit"`` for
          bitsandbytes choices.

    Raises:
        NotImplementedError: Unknown ``mode`` for an auto choice, unknown fp8
            variant, unknown bnb variant, or any other unrecognized name.
    """
    # Must run before normalization: ``None.lower()`` would raise.
    if s is None:
        return None
    s = s.lower().strip()

    if s in ("default", "as-is"):
        return None
    if s in ("auto", "auto (comfy)"):
        if mode == "vae":
            # Must be the VAE *dtype*; vae_device() would return a device.
            return model_management.vae_dtype()
        if mode == "text_encoder":
            return model_management.text_encoder_dtype()
        if mode == "unet":
            return model_management.unet_dtype()
        raise NotImplementedError(f"Unknown dtype mode '{mode}'")
    # "auto (hf/bnb)" contains "bnb", so it must be matched here, before the
    # generic bnb check below.
    if s in ("none", "auto (hf)", "auto (hf/bnb)"):
        return None
    if s in ("fp32", "float32", "float"):
        return torch.float32
    if s in ("bf16", "bfloat16"):
        return torch.bfloat16
    if s in ("fp16", "float16", "half"):
        return torch.float16
    if "fp8" in s or "float8" in s:
        if "e5m2" in s:
            return torch.float8_e5m2
        if "e4m3" in s:
            return torch.float8_e4m3fn
        raise NotImplementedError(f"Unknown 8bit dtype '{s}'")
    if "bnb" in s:
        # Explicit check instead of ``assert``: asserts vanish under -O and
        # would let arbitrary "bnb..." strings through.
        if s not in ("bnb8bit", "bnb4bit"):
            raise NotImplementedError(f"Unknown bnb mode '{s}'")
        return s
    raise NotImplementedError(f"Unknown dtype '{s}'")