Spaces:
Runtime error
Runtime error
File size: 5,251 Bytes
ad93086 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import gguf
import torch
import os
import json
import safetensors.torch
import backend.misc.checkpoint_pickle
from backend.operations_gguf import ParameterGGUF
def read_arbitrary_config(directory):
    """Parse and return the ``config.json`` found directly inside *directory*.

    Raises:
        FileNotFoundError: If ``<directory>/config.json`` does not exist.
    """
    config_file = os.path.join(directory, 'config.json')

    if os.path.exists(config_file):
        with open(config_file, 'rt', encoding='utf-8') as handle:
            return json.load(handle)

    raise FileNotFoundError(f"No config.json file found in the directory: {directory}")
def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its state dict.

    The loader is picked from the case-insensitive file extension:

    * ``.safetensors`` -> ``safetensors.torch.load_file``
    * ``.gguf`` -> ``gguf.GGUFReader``; every tensor is wrapped in
      ``ParameterGGUF`` and keyed by ``str(tensor.name)``
    * anything else -> ``torch.load``

    Args:
        ckpt: Path of the checkpoint file.
        safe_load: Only used on the ``torch.load`` path. When true and the
            installed ``torch.load`` accepts ``weights_only``, the file is
            loaded with ``weights_only=True``; otherwise a warning is printed
            and the pickle-based path is used.
        device: ``torch.device`` to load onto. Defaults to CPU.

    Returns:
        The loaded mapping. For ``torch.load`` checkpoints that contain a
        ``"state_dict"`` entry, that entry is returned instead of the
        top-level object.
    """
    if device is None:
        device = torch.device("cpu")

    lowered = ckpt.lower()

    if lowered.endswith(".safetensors"):
        return safetensors.torch.load_file(ckpt, device=device.type)

    if lowered.endswith(".gguf"):
        reader = gguf.GGUFReader(ckpt)
        return {str(tensor.name): ParameterGGUF(tensor) for tensor in reader.tensors}

    if safe_load:
        import inspect

        # Inspect the real signature rather than ``__code__.co_varnames``:
        # co_varnames also lists local variables and is missing or misleading
        # when torch.load has been wrapped by a decorator or callable object.
        try:
            supports_weights_only = 'weights_only' in inspect.signature(torch.load).parameters
        except (TypeError, ValueError):
            supports_weights_only = False

        if not supports_weights_only:
            print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False

    if safe_load:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
    else:
        # NOTE(review): this path unpickles the file. backend.misc.checkpoint_pickle
        # is presumably a restricted unpickler -- confirm before loading
        # checkpoints from untrusted sources.
        pl_sd = torch.load(ckpt, map_location=device, pickle_module=backend.misc.checkpoint_pickle)

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # Some checkpoints nest the weights under "state_dict"; unwrap them.
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]
    return pl_sd
def set_attr(obj, attr, value):
    """Assign *value* to the dotted attribute path *attr* under *obj*.

    The value is wrapped in a ``torch.nn.Parameter`` with
    ``requires_grad=False`` before being assigned to the final attribute.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    frozen = torch.nn.Parameter(value, requires_grad=False)
    setattr(target, leaf, frozen)
def set_attr_raw(obj, attr, value):
    """Assign *value* unchanged to the dotted attribute path *attr* under *obj*."""
    parts = attr.split(".")
    leaf = parts.pop()
    holder = obj
    for part in parts:
        holder = getattr(holder, part)
    setattr(holder, leaf, value)
def copy_to_param(obj, attr, value):
    """Copy *value* in place into the tensor found at dotted path *attr*.

    The existing attribute object is kept; only its ``.data`` is overwritten
    via ``copy_``.
    """
    *parents, leaf = attr.split(".")
    owner = obj
    for part in parents:
        owner = getattr(owner, part)
    getattr(owner, leaf).data.copy_(value)
def get_attr(obj, attr):
    """Return the object reached by following dotted path *attr* from *obj*."""
    current = obj
    for part in attr.split("."):
        current = getattr(current, part)
    return current
def get_attr_with_parent(obj, attr):
    """Resolve dotted path *attr* and also report where the final hop started.

    Returns:
        A ``(parent, name, value)`` tuple: the object holding the last
        attribute, the last attribute name, and the resolved value.
    """
    owner, leaf = obj, None
    for leaf in attr.split("."):
        # Right-hand side is evaluated first, so ``owner`` receives the object
        # from before this hop while ``obj`` advances one level.
        owner, obj = obj, getattr(obj, leaf)
    return owner, leaf, obj
def calculate_parameters(sd, prefix=""):
    """Return the total ``nelement()`` of every entry whose key starts with *prefix*.

    With the default empty prefix every entry of *sd* is counted.
    """
    return sum(
        tensor.nelement()
        for key, tensor in sd.items()
        if key.startswith(prefix)
    )
def tensor2parameter(x):
    """Return *x* as a ``torch.nn.Parameter``.

    An existing ``Parameter`` is returned as-is; anything else is wrapped in
    a new ``Parameter`` with ``requires_grad=False``.
    """
    if not isinstance(x, torch.nn.Parameter):
        return torch.nn.Parameter(x, requires_grad=False)
    return x
def fp16_fix(x):
# An interesting trick to avoid fp16 overflow
# Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
# Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
if x.dtype in [torch.float16]:
return x.clip(-32768.0, 32768.0)
return x
def dtype_to_element_size(dtype):
    """Return the per-element byte size of a ``torch.dtype``.

    Raises:
        ValueError: If *dtype* is not a ``torch.dtype`` instance.
    """
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Invalid dtype: {dtype}")

    # An empty tensor of the requested dtype is used to read the element size.
    probe = torch.tensor([], dtype=dtype)
    return probe.element_size()
def nested_compute_size(obj, element_size):
    """Sum ``nelement() * element_size`` over every tensor nested in *obj*.

    Dicts (by value), lists and tuples are walked recursively. Any other
    non-tensor object contributes 0.
    """
    if isinstance(obj, dict):
        return sum(nested_compute_size(child, element_size) for child in obj.values())

    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(child, element_size) for child in obj)

    if isinstance(obj, torch.Tensor):
        return obj.nelement() * element_size

    return 0
def nested_move_to_device(obj, **kwargs):
    """Apply ``tensor.to(**kwargs)`` to every tensor nested inside *obj*.

    Dicts and lists are updated in place and returned; tuples are rebuilt as
    new tuples; a bare tensor yields the result of ``.to``; any other object
    is returned unchanged.
    """
    if isinstance(obj, dict):
        for key in obj:
            obj[key] = nested_move_to_device(obj[key], **kwargs)
        return obj

    if isinstance(obj, list):
        for index, item in enumerate(obj):
            obj[index] = nested_move_to_device(item, **kwargs)
        return obj

    if isinstance(obj, tuple):
        return tuple(nested_move_to_device(item, **kwargs) for item in obj)

    if isinstance(obj, torch.Tensor):
        return obj.to(**kwargs)

    return obj
def get_state_dict_after_quant(model, prefix=''):
    """Return a cloned copy of ``model.state_dict()`` with *prefix* added to every key.

    Before the state dict is read, every submodule whose ``weight`` carries a
    falsy ``bnb_quantized`` attribute is moved to CUDA and then back to the
    device its weight was on.
    """
    for module in model.modules():
        if not (hasattr(module, 'weight') and hasattr(module.weight, 'bnb_quantized')):
            continue
        if module.weight.bnb_quantized:
            continue

        # NOTE(review): presumably the round trip through CUDA is what makes
        # bitsandbytes quantize the weight -- confirm against the bnb layer code.
        home_device = module.weight.device
        module.cuda()
        module.to(home_device)

    return {(prefix + key): tensor.clone() for key, tensor in model.state_dict().items()}
def beautiful_print_gguf_state_dict_statics(state_dict):
    """Print how many entries of *state_dict* use each GGUF quantization class.

    Every value with a non-``None`` ``gguf_cls`` attribute is counted under
    ``gguf_cls.__name__``. Values without that attribute are ignored. The
    tally is printed as a plain dict in first-seen order.

    Returns:
        None.
    """
    type_counts = {}
    for value in state_dict.values():
        gguf_cls = getattr(value, 'gguf_cls', None)
        if gguf_cls is None:
            continue
        type_name = gguf_cls.__name__
        type_counts[type_name] = type_counts.get(type_name, 0) + 1

    print(f'GGUF state dict: {type_counts}')
|