import concurrent.futures
import time
import torch

from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from typing import Dict, List


def check_file_size(source_file: Path, target_file: Path):
    """
    Check that two files are close in size
    """
    source_file_size = source_file.stat().st_size
    target_file_size = target_file.stat().st_size

    if (source_file_size - target_file_size) / source_file_size > 0.01:
        raise RuntimeError(
            f"""The file size different is more than 1%:
         - {source_file}: {source_file_size}
         - {target_file}: {target_file_size}
         """
        )


def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
    """
    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
    remove them
    """
    ptrs = defaultdict(list)
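    # Group tensor names by the data pointer of their underlying storage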
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)

    # Iterate over all found memory addresses
    for ptr, names in ptrs.items():
        if len(names) > 1:
            # Multiple tensors point to the same memory
            # Only keep the first tensor
            for name in names[1:]:
                tensors.pop(name)


def convert_file(pt_file: Path, st_file: Path):
    """
    Convert a pytorch file to a safetensors file
    """
    logger.info(f"Convert {pt_file} to {st_file}.")

    pt_state = torch.load(pt_file, map_location="cpu")
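    # Some checkpoint formats (e.g. PyTorch Lightning) nest the weights under a "state_dict" key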
    if "state_dict" in pt_state:
        pt_state = pt_state["state_dict"]

    remove_shared_pointers(pt_state)

    # Tensors need to be contiguous
    pt_state = {k: v.contiguous() for k, v in pt_state.items()}

    st_file.parent.mkdir(parents=True, exist_ok=True)
    save_file(pt_state, str(st_file), metadata={"format": "pt"})

    # Check that both files are close in size
    check_file_size(pt_file, st_file)

    # Reload the safetensors state and check that every tensor matches the original
    st_state = load_file(str(st_file))
    for k in st_state:
        pt_tensor = pt_state[k]
        st_tensor = st_state[k]
        if not torch.equal(pt_tensor, st_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def convert_files(pt_files: List[Path], st_files: List[Path]):
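    """
    Convert a list of pytorch files to safetensors files, running up to 5 conversions in parallel
    """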
    assert len(pt_files) == len(st_files)

    executor = ThreadPoolExecutor(max_workers=5)
    futures = [
        executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
        for pt_file, st_file in zip(pt_files, st_files)
    ]

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    start_time = time.time()
    for i, future in enumerate(concurrent.futures.as_completed(futures)):
        # Re-raise any exception raised in the worker thread instead of silently dropping it
        future.result()

        # Estimate the remaining time from the average duration of the finished conversions
        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = len(futures) - (i + 1)
        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0

        logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")