# Provenance (from the hosting page header): uploaded by jaxmetaverse via
# "Upload folder using huggingface_hub", commit 82ea528 (verified).
import torch
from comfy import model_management
def string_to_dtype(s="none", mode=None):
s = s.lower().strip()
if s in ["default", "as-is"]:
return None
elif s in ["auto", "auto (comfy)"]:
if mode == "vae":
return model_management.vae_device()
elif mode == "text_encoder":
return model_management.text_encoder_dtype()
elif mode == "unet":
return model_management.unet_dtype()
else:
raise NotImplementedError(f"Unknown dtype mode '{mode}'")
elif s in ["none", "auto (hf)", "auto (hf/bnb)"]:
return None
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16", "half"]:
return torch.float16
elif "fp8" in s or "float8" in s:
if "e5m2" in s:
return torch.float8_e5m2
elif "e4m3" in s:
return torch.float8_e4m3fn
else:
raise NotImplementedError(f"Unknown 8bit dtype '{s}'")
elif "bnb" in s:
assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'"
return s
elif s is None:
return None
else:
raise NotImplementedError(f"Unknown dtype '{s}'")