English
naveensp commited on
Commit
17776d2
·
verified ·
1 Parent(s): efb31ab

Delete safetensors_util.py

Browse files
Files changed (1) hide show
  1. safetensors_util.py +0 -81
safetensors_util.py DELETED
@@ -1,81 +0,0 @@
1
- import base64
2
- import pickle
3
- from dataclasses import dataclass
4
- from typing import Dict, Optional, Tuple
5
-
6
- import safetensors.torch
7
- import torch
8
-
9
- from .aliases import PathOrStr
10
-
11
- __all__ = [
12
- "state_dict_to_safetensors_file",
13
- "safetensors_file_to_state_dict",
14
- ]
15
-
16
-
17
- @dataclass(eq=True, frozen=True)
18
- class STKey:
19
- keys: Tuple
20
- value_is_pickled: bool
21
-
22
-
23
- def encode_key(key: STKey) -> str:
24
- b = pickle.dumps((key.keys, key.value_is_pickled))
25
- b = base64.urlsafe_b64encode(b)
26
- return str(b, "ASCII")
27
-
28
-
29
- def decode_key(key: str) -> STKey:
30
- b = base64.urlsafe_b64decode(key)
31
- keys, value_is_pickled = pickle.loads(b)
32
- return STKey(keys, value_is_pickled)
33
-
34
-
35
- def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
36
- result = {}
37
- for key, value in d.items():
38
- if isinstance(value, torch.Tensor):
39
- result[STKey((key,), False)] = value
40
- elif isinstance(value, dict):
41
- value = flatten_dict(value)
42
- for inner_key, inner_value in value.items():
43
- result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
44
- else:
45
- pickled = bytearray(pickle.dumps(value))
46
- pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
47
- result[STKey((key,), True)] = pickled_tensor
48
- return result
49
-
50
-
51
- def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
52
- result: Dict = {}
53
-
54
- for key, value in d.items():
55
- if key.value_is_pickled:
56
- value = pickle.loads(value.numpy().data)
57
-
58
- target_dict = result
59
- for k in key.keys[:-1]:
60
- new_target_dict = target_dict.get(k)
61
- if new_target_dict is None:
62
- new_target_dict = {}
63
- target_dict[k] = new_target_dict
64
- target_dict = new_target_dict
65
- target_dict[key.keys[-1]] = value
66
-
67
- return result
68
-
69
-
70
- def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
71
- state_dict = flatten_dict(state_dict)
72
- state_dict = {encode_key(k): v for k, v in state_dict.items()}
73
- safetensors.torch.save_file(state_dict, filename)
74
-
75
-
76
- def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
77
- if map_location is None:
78
- map_location = "cpu"
79
- state_dict = safetensors.torch.load_file(filename, device=map_location)
80
- state_dict = {decode_key(k): v for k, v in state_dict.items()}
81
- return unflatten_dict(state_dict)