|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from functools import lru_cache |
|
from typing import Any, List, Optional |
|
|
|
from mergekit import merge_methods |
|
from mergekit.architecture import ( |
|
ArchitectureInfo, |
|
ConfiguredArchitectureInfo, |
|
WeightInfo, |
|
) |
|
from mergekit.common import ImmutableMap, ModelReference |
|
from mergekit.config import ( |
|
ConfigReader, |
|
InputSliceDefinition, |
|
MergeConfiguration, |
|
OutputSliceDefinition, |
|
) |
|
from mergekit.graph import Task |
|
from mergekit.io.tasks import ( |
|
FinalizeModel, |
|
GatherTensors, |
|
LoaderCache, |
|
SaveTensor, |
|
TensorWriterTask, |
|
) |
|
from mergekit.merge_methods import MergeMethod |
|
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge |
|
from mergekit.options import MergeOptions |
|
from mergekit.tokenizer import BuildTokenizer |
|
|
|
|
|
class MergePlanner:
    """Builds the list of tasks needed to carry out a merge configuration.

    ``plan()`` is the entry point: it normalizes the configuration into
    output slices, creates one tensor-merge + save task per output weight
    (pre-layer weights, every layer of every slice, post-layer weights),
    and finishes with a ``FinalizeModel`` task plus, when a tokenizer source
    is configured, the tokenizer build task.
    """

    config: MergeConfiguration
    arch_info: ArchitectureInfo
    options: MergeOptions
    out_model_config: Any
    _writer_task: TensorWriterTask
    _method: MergeMethod
    # Save tasks accumulated by plan_tensor(). Deliberately has no class-level
    # default: a shared list here would be mutated by every instance.
    _tasks: List[Task]
    # Index of the next output layer to be planned.
    _current_layers: int = 0
    # Only set when the configuration names a tokenizer source.
    _tokenizer_task: Optional[BuildTokenizer] = None

    def __init__(
        self,
        config: MergeConfiguration,
        arch_info: ArchitectureInfo,
        out_path: str,
        options: MergeOptions,
        out_model_config: Any,
    ):
        """Store the merge inputs and create the shared writer task.

        Args:
            config: The merge configuration to plan.
            arch_info: Architecture description used to enumerate weights.
            out_path: Output path handed to the tensor writer task.
            options: Merge options (shard size, serialization format,
                ``trust_remote_code``, tensor cloning).
            out_model_config: Configuration of the output model, passed to
                ``arch_info`` when enumerating output weights.
        """
        self.config = config
        self.arch_info = arch_info
        self.options = options
        self.out_model_config = out_model_config
        self._method = merge_methods.get(config.merge_method)
        self._writer_task = TensorWriterTask(
            out_path=out_path,
            max_shard_size=options.out_shard_size,
            safe_serialization=options.safe_serialization,
        )

        # Per-instance state. Keeping these on the instance (rather than on
        # the class or in a module-level cache) means planners never share
        # state and can be garbage-collected normally.
        self._tasks = []
        self._arch_info_cache = {}

        if config.tokenizer_source:
            self._tokenizer_task = BuildTokenizer(
                base_model=config.base_model,
                referenced_models=tuple(config.referenced_models()),
                tokenizer_source=config.tokenizer_source,
                trust_remote_code=options.trust_remote_code,
            )

    def model_arch_info(self, model: ModelReference) -> ConfiguredArchitectureInfo:
        """Return the architecture info bound to ``model``'s own config.

        Results are memoized per planner instance, keyed by ``model`` (which
        must therefore be hashable), so each model's config is requested at
        most once per planner.
        """
        try:
            return self._arch_info_cache[model]
        except KeyError:
            pass

        configured = ConfiguredArchitectureInfo(
            info=self.arch_info,
            config=model.config(trust_remote_code=self.options.trust_remote_code),
        )
        self._arch_info_cache[model] = configured
        return configured

    def normalize_config(self) -> None:
        """Rewrite a ``models``-style configuration into a single output slice.

        When ``config.models`` is set, every listed model contributes an
        input slice spanning all of its layers; the base model is appended
        as an extra input if it is configured but not listed. The result
        replaces ``config.slices`` and ``config.models`` is cleared.
        Configurations without ``models`` are left untouched.

        Raises:
            RuntimeError: If both ``models`` and ``slices`` are specified.
        """
        if not self.config.models:
            return

        if self.config.slices:
            raise RuntimeError(
                "Must specify either models to merge or output slices"
            )

        base_model = self.config.base_model
        slices_in = []
        base_included = False

        for model_in in self.config.models:
            if base_model and model_in.model == base_model:
                base_included = True

            model_info = self.model_arch_info(model_in.model)
            slices_in.append(
                InputSliceDefinition(
                    layer_range=[0, model_info.num_layers()],
                    model=model_in.model,
                    parameters=model_in.parameters,
                )
            )

        if base_model and not base_included:
            logging.info("Base model specified but not in input models - adding")
            base_info = self.model_arch_info(base_model)
            slices_in.append(
                InputSliceDefinition(
                    layer_range=[0, base_info.num_layers()],
                    model=base_model,
                )
            )

        self.config.slices = [OutputSliceDefinition(sources=slices_in)]
        self.config.models = None

    def plan_tensor(
        self,
        weight: WeightInfo,
        weights_in: List[WeightInfo],
        models: List[ModelReference],
        cfg_reader: ConfigReader,
    ) -> None:
        """Queue the merge-and-save task for one output weight.

        Args:
            weight: The output weight to produce.
            weights_in: The corresponding input weight for each entry of
                ``models`` (same order, same length).
            models: The models contributing to this weight.
            cfg_reader: Reader used to resolve merge parameters.

        An optional weight that none of the input models contains is skipped
        (logged at info level) and no task is queued for it.
        """
        if weight.optional:
            # Optional weights are only merged if at least one input model
            # actually has the tensor.
            any_weight = any(
                w_in.name in LoaderCache().get(model).index.tensor_paths
                for model, w_in in zip(models, weights_in)
            )
            if not any_weight:
                logging.info(f"Skipping optional weight {weight.name}")
                return

        # Embedding weights use the tokenizer-permutation merge whenever a
        # tokenizer is being built; everything else uses the configured method.
        tensor_merge_method = self._method
        if self._tokenizer_task and weight.is_embed:
            tensor_merge_method = TokenizerPermutationMerge(
                tokenizer_task=self._tokenizer_task
            )

        # Parameters that apply to the merge as a whole (model=None).
        cfg_g = cfg_reader.for_tensor(weight.name)
        global_params = {
            p.name: cfg_g.parameter(
                p.name, model=None, required=p.required, default=p.default_value
            )
            for p in tensor_merge_method.parameters()
        }

        # Parameters resolved separately for each input model.
        tensor_params = {}
        for model, weight_in in zip(models, weights_in):
            is_base = model == cfg_reader.config.base_model
            cfg_m = cfg_reader.for_tensor(weight_in.name)
            tensor_params[model] = {
                p.name: cfg_m.parameter(
                    p.name,
                    model=model,
                    # Required parameters are not enforced for the base model.
                    required=p.required and not is_base,
                    default=p.default_value,
                )
                for p in tensor_merge_method.tensor_parameters()
            }

        gather_tensors = GatherTensors(
            weight_info=ImmutableMap(data=dict(zip(models, weights_in))),
            dtype=self.config.dtype,
        )

        tensor_task = tensor_merge_method.make_task(
            output_weight=weight,
            tensors=gather_tensors,
            parameters=ImmutableMap(data=global_params),
            tensor_parameters=ImmutableMap(
                data={
                    model: ImmutableMap(data=params)
                    for model, params in tensor_params.items()
                }
            ),
            base_model=self.config.base_model,
        )
        save_task = SaveTensor(
            tensor_name=weight.name,
            tensor_task=tensor_task,
            writer_task=self._writer_task,
            clone=self.options.clone_tensors,
            optional=weight.optional,
        )
        self._tasks.append(save_task)

    def plan_layer(
        self,
        sources: List[InputSliceDefinition],
        layer_offset: int,
        t: float,
        cfg_reader: ConfigReader,
    ) -> None:
        """Plan every weight of the next output layer.

        Args:
            sources: Input slices feeding this output layer.
            layer_offset: Offset of the layer within each input slice; it is
                added to each source's ``layer_range[0]``.
            t: Value handed to ``cfg_reader.with_t`` for this layer.
            cfg_reader: Reader used to resolve merge parameters.

        Output and input weights are paired by position, so each source's
        layer weights are indexed in the same order as the output layer's.
        Advances the output layer counter by one.
        """
        weights_out: List[WeightInfo] = self.arch_info.layer_weights(
            index=self._current_layers,
            config=self.out_model_config,
        )
        weights_in: List[List[WeightInfo]] = [
            self.model_arch_info(s.model).layer_weights(
                index=s.layer_range[0] + layer_offset
            )
            for s in sources
        ]
        models = [s.model for s in sources]

        for idx, w_o in enumerate(weights_out):
            self.plan_tensor(
                weight=w_o,
                weights_in=[per_model[idx] for per_model in weights_in],
                models=models,
                cfg_reader=cfg_reader.with_t(t),
            )

        self._current_layers += 1

    def plan_slice(self, definition: OutputSliceDefinition) -> None:
        """Plan all layers of one output slice.

        Each layer is planned with ``t = idx / (num_layers - 1)``, i.e. 0 for
        the first layer and 1 for the last; a single-layer slice uses ``t = 1``.

        Raises:
            RuntimeError: If the slice has no input sources, or if its input
                sources do not all span the same number of layers.
        """
        slice_lengths = [
            s.layer_range[1] - s.layer_range[0] for s in definition.sources
        ]
        if not slice_lengths:
            # Without this guard the lookup below fails with a bare IndexError.
            raise RuntimeError("Output slice must have at least one input source")
        num_layers = slice_lengths[0]
        if any(length != num_layers for length in slice_lengths):
            raise RuntimeError(
                "All inputs to a slice must contain the same number of layers"
            )

        cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0)
        for idx in range(num_layers):
            if num_layers > 1:
                t = idx / (num_layers - 1)
            else:
                t = 1

            self.plan_layer(
                definition.sources,
                layer_offset=idx,
                t=t,
                cfg_reader=cfg_reader,
            )

    def _plan_edge_weights(self, weights, slice_index: int, t: float) -> None:
        """Plan weights that sit outside the layer stack (pre/post weights).

        Each weight is merged across the sources of ``config.slices[slice_index]``
        using the same weight definition for every source. The slice is looked
        up inside the loop so nothing is indexed when ``weights`` is empty.
        """
        for weight_info in weights:
            out_slice = self.config.slices[slice_index]
            sources = out_slice.sources
            self.plan_tensor(
                weight_info,
                [weight_info] * len(sources),
                [s.model for s in sources],
                ConfigReader(
                    config=self.config,
                    t=t,
                    tensor_name=weight_info.name,
                ).for_out_slice(out_slice),
            )

    def plan(self) -> List[Task]:
        """Produce the full task list for the configured merge.

        Returns:
            A new list containing one save task per planned tensor, followed
            by the ``FinalizeModel`` task, followed by the tokenizer task when
            one is configured.
        """
        self.normalize_config()
        self._tasks = []

        # Pre-layer weights are merged across the sources of the first slice
        # with t=0; post-layer weights across those of the last with t=1.
        self._plan_edge_weights(
            self.arch_info.pre_weights(config=self.out_model_config),
            slice_index=0,
            t=0,
        )

        for out_slice in self.config.slices:
            self.plan_slice(out_slice)

        self._plan_edge_weights(
            self.arch_info.post_weights(config=self.out_model_config),
            slice_index=-1,
            t=1,
        )

        # The finalize task depends on every save task queued so far.
        self._tasks.append(
            FinalizeModel(
                tensor_save_tasks=tuple(self._tasks), writer_task=self._writer_task
            )
        )
        res = list(self._tasks)
        if self._tokenizer_task:
            res.append(self._tokenizer_task)
        return res
|
|