English
naveensp commited on
Commit
b79fb53
·
verified ·
1 Parent(s): 2178aa3

Upload safetensors_util.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. safetensors_util.py +81 -0
safetensors_util.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import pickle
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Optional, Tuple
5
+
6
+ import safetensors.torch
7
+ import torch
8
+
9
+ from .aliases import PathOrStr
10
+
11
+ __all__ = [
12
+ "state_dict_to_safetensors_file",
13
+ "safetensors_file_to_state_dict",
14
+ ]
15
+
16
+
17
+ @dataclass(eq=True, frozen=True)
18
+ class STKey:
19
+ keys: Tuple
20
+ value_is_pickled: bool
21
+
22
+
23
+ def encode_key(key: STKey) -> str:
24
+ b = pickle.dumps((key.keys, key.value_is_pickled))
25
+ b = base64.urlsafe_b64encode(b)
26
+ return str(b, "ASCII")
27
+
28
+
29
+ def decode_key(key: str) -> STKey:
30
+ b = base64.urlsafe_b64decode(key)
31
+ keys, value_is_pickled = pickle.loads(b)
32
+ return STKey(keys, value_is_pickled)
33
+
34
+
35
+ def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
36
+ result = {}
37
+ for key, value in d.items():
38
+ if isinstance(value, torch.Tensor):
39
+ result[STKey((key,), False)] = value
40
+ elif isinstance(value, dict):
41
+ value = flatten_dict(value)
42
+ for inner_key, inner_value in value.items():
43
+ result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
44
+ else:
45
+ pickled = bytearray(pickle.dumps(value))
46
+ pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
47
+ result[STKey((key,), True)] = pickled_tensor
48
+ return result
49
+
50
+
51
+ def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
52
+ result: Dict = {}
53
+
54
+ for key, value in d.items():
55
+ if key.value_is_pickled:
56
+ value = pickle.loads(value.numpy().data)
57
+
58
+ target_dict = result
59
+ for k in key.keys[:-1]:
60
+ new_target_dict = target_dict.get(k)
61
+ if new_target_dict is None:
62
+ new_target_dict = {}
63
+ target_dict[k] = new_target_dict
64
+ target_dict = new_target_dict
65
+ target_dict[key.keys[-1]] = value
66
+
67
+ return result
68
+
69
+
70
+ def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
71
+ state_dict = flatten_dict(state_dict)
72
+ state_dict = {encode_key(k): v for k, v in state_dict.items()}
73
+ safetensors.torch.save_file(state_dict, filename)
74
+
75
+
76
+ def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
77
+ if map_location is None:
78
+ map_location = "cpu"
79
+ state_dict = safetensors.torch.load_file(filename, device=map_location)
80
+ state_dict = {decode_key(k): v for k, v in state_dict.items()}
81
+ return unflatten_dict(state_dict)