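"""Tensor I/O tasks for mergekit's task graph.

Defines a shared cache of lazy tensor loaders plus Task subclasses for
loading, gathering, and writing model weights.
"""
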
import os
from typing import Dict, Optional, Tuple

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, dtype_from_name
from mergekit.graph import Task
from mergekit.io.lazy_tensor_loader import LazyTensorLoader
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions


class LoaderCache:
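    """Shared singleton cache of lazy tensor loaders, keyed by model.

    All instances share one loader table, so each checkpoint is indexed at
    most once per process. Illustrative usage, where ``merge_options`` and
    ``model_ref`` are stand-ins for real values::

        cache = LoaderCache()
        cache.setup(options=merge_options)
        loader = cache.get(model_ref)
    """
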
    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    lora_cache_dir: Optional[str] = None
    hf_cache_dir: Optional[str] = None
    lazy_unpickle: bool = False
    trust_remote_code: bool = False

    _instance: Optional["LoaderCache"] = None

    def __new__(cls) -> "LoaderCache":
        if cls._instance is None:
            cls._instance = super(LoaderCache, cls).__new__(cls)
        return cls._instance

    def get(self, model: ModelReference) -> LazyTensorLoader:
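        """Return a cached loader for ``model``, creating one on first use.

        ``model.merged`` first folds any LoRA adapter into its base model,
        so the loader always sees a plain checkpoint.
        """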
        if model not in self.loaders:
            merged = model.merged(
                cache_dir=self.lora_cache_dir, trust_remote_code=self.trust_remote_code
            )
            self.loaders[model] = merged.lazy_loader(
                cache_dir=self.hf_cache_dir, lazy_unpickle=self.lazy_unpickle
            )
        return self.loaders[model]

    def flush_all(self):
        for loader in self.loaders.values():
            loader.flush()

    def setup(self, options: MergeOptions):
        self.lora_cache_dir = options.lora_merge_cache
        self.hf_cache_dir = options.transformers_cache
        self.lazy_unpickle = options.lazy_unpickle
        self.trust_remote_code = options.trust_remote_code


def _normalized_shard_name(path: str) -> str:
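    """Map a shard filename to a normalized sort key.

    Lowercases and rewrites the ``pytorch_model`` prefix to ``model`` so
    that e.g. ``pytorch_model-00001-of-00002.bin`` and
    ``model-00001-of-00002.safetensors`` yield the same key.
    """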
    name, _ext = os.path.splitext(os.path.basename(path))
    return name.lower().replace("pytorch_model", "model")


class LoadTensor(Task[Optional[torch.Tensor]]):
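    """Task that loads a single named tensor from a model's checkpoint.

    Returns None when the tensor is absent and ``optional`` is set;
    otherwise a missing tensor raises a RuntimeError. Optionally casts to
    ``dtype`` and places the result on ``device`` (CPU by default).
    """
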
    model: ModelReference
    tensor: str
    dtype: Optional[str] = None
    device: Optional[str] = None
    optional: bool = False
    aliases: Optional[Tuple[str, ...]] = None

    def arguments(self) -> Dict[str, Task]:
        return {}

    def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]:
        all_names = [self.tensor] + list(self.aliases or [])
        for name in all_names:
            if name in loader.index.tensor_paths:
                return name
        return None

    def execute(self) -> Optional[torch.Tensor]:
        loader = LoaderCache().get(self.model)
        name = self._resolve_name(loader)
        if not name:
            if not self.optional:
                raise RuntimeError(
                    f"Tensor {self.tensor} required but not present in model {self.model}"
                )
            return None

        x = loader.get_tensor(name, device=self.device or "cpu")
        if self.dtype:
            x = x.to(dtype=dtype_from_name(self.dtype))
        return x

    def priority(self) -> int:
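        # Strongly negative priority: a hint to the scheduler to run loads
        # as late as practical.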
        return -1000

    def group_label(self) -> Optional[str]:
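        """Label this task with the normalized name of its tensor's shard."""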
        loader = LoaderCache().get(self.model)
        name = self._resolve_name(loader)
        if name:
            shard_path = loader.index.tensor_paths[name]
            return _normalized_shard_name(shard_path)
        return None


class GatherTensors(Task[Dict[ModelReference, torch.Tensor]]):
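    """Task that collects the same weight from every input model.

    Models whose tensor is missing (and marked optional) are simply
    omitted from the result dictionary.
    """
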
    weight_info: ImmutableMap[ModelReference, WeightInfo]
    dtype: Optional[str] = None
    device: Optional[str] = None

    def arguments(self) -> Dict[str, Task]:
        return {
            f"{str(model)}:{wi.name}": LoadTensor(
                model=model,
                tensor=wi.name,
                dtype=self.dtype,
                device=self.device,
                optional=wi.optional,
                aliases=wi.aliases,
            )
            for (model, wi) in self.weight_info.items()
        }

    def group_label(self) -> Optional[str]:
        return max(t.group_label() or "" for t in self.arguments().values())

    def priority(self) -> int:
        return -10

    def execute(self, **kwargs) -> Dict[ModelReference, torch.Tensor]:
        key2model = {
            f"{str(model)}:{wi.name}": model for (model, wi) in self.weight_info.items()
        }
        return {
            key2model[key]: kwargs[key] for key in key2model if kwargs[key] is not None
        }


class TensorWriterTask(Task[TensorWriter]):
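    """Task that constructs the TensorWriter used to emit the output model."""
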
    out_path: str
    max_shard_size: int
    safe_serialization: bool = True

    def arguments(self) -> Dict[str, Task]:
        return {}

    def execute(self, **_kwargs) -> TensorWriter:
        return TensorWriter(
            self.out_path,
            max_shard_size=self.max_shard_size,
            safe_serialization=self.safe_serialization,
        )


class SaveTensor(Task[None]):
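    """Task that writes a single computed tensor to the output model.

    A missing optional tensor is skipped; a missing required tensor raises.
    """
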
    tensor_name: str
    tensor_task: Task
    writer_task: TensorWriterTask
    clone: bool
    optional: bool = False

    def arguments(self) -> Dict[str, Task]:
        return {"writer": self.writer_task, "tensor": self.tensor_task}

    def priority(self) -> int:
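        # Strongly positive priority: a hint to the scheduler to write
        # finished tensors promptly so their memory can be reclaimed.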
        return 1000

    def group_label(self) -> Optional[str]:
        return self.tensor_task.group_label()

    def execute(self, writer: TensorWriter, tensor: Optional[torch.Tensor]) -> None:
        if tensor is None:
            if not self.optional:
                raise RuntimeError(f"No value for required tensor {self.tensor_name}")
            return
        writer.save_tensor(name=self.tensor_name, tensor=tensor, clone=self.clone)


class FinalizeModel(Task[None]):
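    """Task that finalizes the TensorWriter once all tensors are saved.

    The save tasks are declared as (unused) arguments solely to make this
    task depend on, and thus run after, every SaveTensor in the graph.
    """
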
    tensor_save_tasks: Tuple[Task, ...]
    writer_task: TensorWriterTask

    def arguments(self) -> Dict[str, Task]:
        return {
            "writer": self.writer_task,
            **{f"_unused_{idx}": t for idx, t in enumerate(self.tensor_save_tasks)},
        }

    def execute(self, writer: TensorWriter, **kwargs) -> None:
        writer.finalize()