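"""Task definitions for loading, gathering, and writing model tensors in mergekit."""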
import os
from typing import Dict, Optional, Tuple
import torch
from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, dtype_from_name
from mergekit.graph import Task
from mergekit.io.lazy_tensor_loader import LazyTensorLoader
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions
class LoaderCache:
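    """Process-wide singleton holding a LazyTensorLoader for each referenced model."""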
loaders: Dict[ModelReference, LazyTensorLoader] = {}
lora_cache_dir: Optional[str] = None
hf_cache_dir: Optional[str] = None
lazy_unpickle: bool = False
trust_remote_code: bool = False
# singleton instance
_instance: Optional["LoaderCache"] = None
def __new__(cls) -> "LoaderCache":
if cls._instance is None:
            cls._instance = super().__new__(cls)
return cls._instance
def get(self, model: ModelReference) -> LazyTensorLoader:
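        """Return the loader for ``model``, creating and caching it on first use.

        ``model.merged(...)`` resolves any LoRA adapter into plain weights
        (using ``lora_cache_dir``) before the lazy loader is built.
        """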
if model not in self.loaders:
merged = model.merged(
cache_dir=self.lora_cache_dir, trust_remote_code=self.trust_remote_code
)
self.loaders[model] = merged.lazy_loader(
cache_dir=self.hf_cache_dir, lazy_unpickle=self.lazy_unpickle
)
return self.loaders[model]
def flush_all(self):
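        """Flush every cached loader, releasing any shard data held in memory."""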
for loader in self.loaders.values():
loader.flush()
def setup(self, options: MergeOptions):
self.lora_cache_dir = options.lora_merge_cache
self.hf_cache_dir = options.transformers_cache
self.lazy_unpickle = options.lazy_unpickle
self.trust_remote_code = options.trust_remote_code
def _normalized_shard_name(path: str) -> str:
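    """Normalize a shard filename to a canonical key, so that legacy
    ``pytorch_model-*`` and safetensors ``model-*`` shard names compare equal."""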
name, _ext = os.path.splitext(os.path.basename(path))
return name.lower().replace("pytorch_model", "model")
class LoadTensor(Task[Optional[torch.Tensor]]):
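    """Task that loads a single named tensor from a model, optionally casting it
    to a requested dtype and moving it to a requested device."""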
model: ModelReference
tensor: str
dtype: Optional[str] = None
device: Optional[str] = None
optional: bool = False
aliases: Optional[Tuple[str, ...]] = None
def arguments(self) -> Dict[str, Task]:
return {}
def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]:
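        """Return the first of the tensor's name or aliases present in the
        loader's index, or None if none match."""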
all_names = [self.tensor] + list(self.aliases or [])
for name in all_names:
if name in loader.index.tensor_paths:
return name
return None
def execute(self) -> Optional[torch.Tensor]:
loader = LoaderCache().get(self.model)
name = self._resolve_name(loader)
if not name:
if not self.optional:
raise RuntimeError(
f"Tensor {self.tensor} required but not present in model {self.model}"
)
return None
x = loader.get_tensor(name, device=self.device or "cpu")
if self.dtype:
x = x.to(dtype=dtype_from_name(self.dtype))
return x
def priority(self) -> int:
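        # Very low priority: defer disk reads until the tensor is actually
        # needed by a downstream task.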
return -1000
def group_label(self) -> Optional[str]:
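        # Group load tasks by the shard file containing the tensor, so tensors
        # stored in the same shard are read together.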
loader = LoaderCache().get(self.model)
name = self._resolve_name(loader)
if name:
shard_path = loader.index.tensor_paths[name]
return _normalized_shard_name(shard_path)
return None
class GatherTensors(Task[Dict[ModelReference, torch.Tensor]]):
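    """Task that gathers the corresponding weight from each input model into a
    single ``{ModelReference: tensor}`` mapping, omitting models whose optional
    tensor is absent."""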
weight_info: ImmutableMap[ModelReference, WeightInfo]
dtype: Optional[str] = None
device: Optional[str] = None
def arguments(self) -> Dict[str, Task]:
return {
f"{str(model)}:{wi.name}": LoadTensor(
model=model,
tensor=wi.name,
dtype=self.dtype,
device=self.device,
optional=wi.optional,
aliases=wi.aliases,
)
for (model, wi) in self.weight_info.items()
}
def group_label(self) -> Optional[str]:
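        # Use the lexicographically greatest shard label among the inputs.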
return max(t.group_label() or "" for t in self.arguments().values())
def priority(self) -> int:
return -10
def execute(self, **kwargs) -> Dict[ModelReference, torch.Tensor]:
key2model = {
f"{str(model)}:{wi.name}": model for (model, wi) in self.weight_info.items()
}
return {
key2model[key]: kwargs[key] for key in key2model if kwargs[key] is not None
}
class TensorWriterTask(Task[TensorWriter]):
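    """Task that constructs the TensorWriter used to emit the output model."""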
out_path: str
max_shard_size: int
safe_serialization: bool = True
def arguments(self) -> Dict[str, Task]:
return {}
def execute(self, **_kwargs) -> TensorWriter:
return TensorWriter(
self.out_path,
max_shard_size=self.max_shard_size,
safe_serialization=self.safe_serialization,
)
class SaveTensor(Task[None]):
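    """Task that writes a single computed tensor to the output via the shared
    TensorWriter. Runs at high priority so finished tensors are written out
    promptly."""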
tensor_name: str
tensor_task: Task
writer_task: TensorWriterTask
clone: bool
optional: bool = False
def arguments(self) -> Dict[str, Task]:
return {"writer": self.writer_task, "tensor": self.tensor_task}
def priority(self) -> int:
return 1000
def group_label(self) -> Optional[str]:
return self.tensor_task.group_label()
def execute(self, writer: TensorWriter, tensor: Optional[torch.Tensor]) -> None:
if tensor is None:
if not self.optional:
raise RuntimeError(f"No value for required tensor {self.tensor_name}")
return
writer.save_tensor(name=self.tensor_name, tensor=tensor, clone=self.clone)
class FinalizeModel(Task[None]):
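    """Task that runs once every SaveTensor task has completed and finalizes
    the writer, flushing any remaining shard data."""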
tensor_save_tasks: Tuple[Task, ...]
writer_task: TensorWriterTask
def arguments(self) -> Dict[str, Task]:
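        # Depend on every SaveTensor task so finalization happens last; their
        # results (all None) are ignored.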
return {
"writer": self.writer_task,
**{f"_unused_{idx}": t for idx, t in enumerate(self.tensor_save_tasks)},
}
    def execute(self, writer: TensorWriter, **_kwargs) -> None:
writer.finalize()
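

# A minimal usage sketch (not part of the original module): load one tensor via
# the primitives above. ``ModelReference.parse`` accepting a Hugging Face repo
# id, the repo id itself, and the tensor name are illustrative assumptions and
# may differ across mergekit versions.
if __name__ == "__main__":
    cache = LoaderCache()
    cache.setup(MergeOptions())  # default options; assumes all fields have defaults

    ref = ModelReference.parse("mistralai/Mistral-7B-v0.1")  # hypothetical model
    task = LoadTensor(
        model=ref,
        tensor="model.embed_tokens.weight",  # hypothetical tensor name
        dtype="float32",
        optional=True,  # return None instead of raising if the name is absent
    )
    weight = task.execute()
    if weight is not None:
        print(weight.shape, weight.dtype)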