File size: 1,095 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from comfy import model_management

def string_to_dtype(s="none", mode=None):
	"""Convert a user-facing dtype name into a torch dtype (or a marker value).

	Matching is case-insensitive and ignores surrounding whitespace.

	Args:
		s: dtype name such as "fp32"/"float32"/"float", "bf16"/"bfloat16",
			"fp16"/"float16"/"half", an fp8 name containing "e4m3" or "e5m2",
			"bnb8bit"/"bnb4bit", "auto"/"auto (comfy)", or one of the
			"no explicit dtype" names ("default", "as-is", "none",
			"auto (hf)", "auto (hf/bnb)"). ``None`` is also accepted.
		mode: which model component an "auto" request is for: "vae",
			"text_encoder" or "unet". Only consulted for "auto"/"auto (comfy)".

	Returns:
		A ``torch.dtype``; ``None`` for ``None`` or any of the
		"no explicit dtype" names; or the normalized string "bnb8bit" /
		"bnb4bit" for bitsandbytes modes. For "auto", whatever ComfyUI's
		``model_management.<component>_dtype()`` returns.

	Raises:
		NotImplementedError: for an unknown dtype name, fp8 variant, bnb mode,
			or an unknown ``mode`` combined with "auto".
	"""
	# Handle None before normalizing: calling .lower() on None would raise.
	if s is None:
		return None
	s = s.lower().strip()

	if s in ["default", "as-is"]:
		return None
	elif s in ["auto", "auto (comfy)"]:
		# Defer to ComfyUI's per-component dtype selection.
		if mode == "vae":
			# Must be the dtype helper; vae_device() would return a device.
			return model_management.vae_dtype()
		elif mode == "text_encoder":
			return model_management.text_encoder_dtype()
		elif mode == "unet":
			return model_management.unet_dtype()
		else:
			raise NotImplementedError(f"Unknown dtype mode '{mode}'")
	elif s in ["none", "auto (hf)", "auto (hf/bnb)"]:
		return None
	elif s in ["fp32", "float32", "float"]:
		return torch.float32
	elif s in ["bf16", "bfloat16"]:
		return torch.bfloat16
	elif s in ["fp16", "float16", "half"]:
		return torch.float16
	elif "fp8" in s or "float8" in s:
		if "e5m2" in s:
			return torch.float8_e5m2
		elif "e4m3" in s:
			return torch.float8_e4m3fn
		else:
			raise NotImplementedError(f"Unknown 8bit dtype '{s}'")
	elif "bnb" in s:
		# Explicit raise instead of assert so the check survives `python -O`.
		if s not in ["bnb8bit", "bnb4bit"]:
			raise NotImplementedError(f"Unknown bnb mode '{s}'")
		return s
	else:
		raise NotImplementedError(f"Unknown dtype '{s}'")