File size: 19,485 Bytes
f14e74e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 |
# Copyright © 2023 Apple Inc.
import textwrap
from typing import Any, Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added in the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        # Both values end up as dict entries (``__setattr__`` writes into the
        # dict); the leading underscore keeps them out of ``parameters()``.
        self._no_grad = set()
        self._training = True

    @property
    def training(self):
        """Boolean indicating if the model is in training mode."""
        return self._training

    def _extra_repr(self):
        """Text placed inside the parentheses of ``repr``; empty by default."""
        return ""

    def __repr__(self):
        """Show the class name, ``_extra_repr()`` and each child module."""
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix=" ")
        if children:
            value += "\n"
        value += ")"
        return value

    def __getattr__(self, key: str):
        # Only reached when normal attribute lookup fails, so dict entries
        # (parameters, submodules, ...) are readable as attributes.
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        # Every attribute assignment is stored as a dict entry.
        self[key] = val

    def load_weights(
        self,
        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
        strict: bool = True,
    ):
        """
        Update the model's weights from a ``.npz`` or a list.

        Args:
            file_or_weights (str or list(tuple(str, mx.array))): The path to
                the weights ``.npz`` file or a list of pairs of parameter names
                and arrays.
            strict (bool, optional): If ``True`` then checks that the provided
                weights exactly match the parameters of the model. Otherwise,
                only the weights actually contained in the model are loaded and
                shapes are not checked. Default: ``True``.

        Raises:
            ValueError: If ``strict`` is ``True`` and the weights contain names
                that are not model parameters, lack some model parameters,
                contain a value that is not an :class:`mlx.core.array`, or have
                a shape different from the corresponding parameter.

        Example:

            .. code-block:: python

                import mlx.core as mx
                import mlx.nn as nn
                model = nn.Linear(10, 10)

                # Load from file
                model.load_weights("weights.npz")

                # Load from list
                weights = [
                    ("weight", mx.random.uniform(shape=(10, 10))),
                    ("bias", mx.zeros((10,))),
                ]
                model.load_weights(weights)

                # Missing weight
                weights = [
                    ("weight", mx.random.uniform(shape=(10, 10))),
                ]

                # Raises a ValueError exception
                model.load_weights(weights)

                # Ok, only updates the weight but not the bias
                model.load_weights(weights, strict=False)
        """
        weights = file_or_weights
        if isinstance(weights, str):
            weights = list(mx.load(weights).items())

        if strict:
            new_weights = dict(weights)
            curr_weights = dict(tree_flatten(self.parameters()))
            if extras := (new_weights.keys() - curr_weights.keys()):
                extras = " ".join(extras)
                raise ValueError(f"Received parameters not in model: {extras}.")
            if missing := (curr_weights.keys() - new_weights.keys()):
                missing = " ".join(missing)
                raise ValueError(f"Missing parameters: {missing}.")
            for k, v in curr_weights.items():
                v_new = new_weights[k]
                if not isinstance(v_new, mx.array):
                    raise ValueError(
                        "Expected mx.array but received "
                        f"{type(v_new)} for parameter {k}"
                    )
                if v_new.shape != v.shape:
                    raise ValueError(
                        f"Expected shape {v.shape} but received "
                        f"shape {v_new.shape} for parameter {k}"
                    )

        # Nothing to load for an empty list; skip the update instead of
        # handing ``tree_unflatten`` an empty sequence.
        if weights:
            self.update(tree_unflatten(weights))

    def save_weights(self, file: str):
        """
        Save the model's weights to a ``.npz`` file.
        """
        mx.savez(file, **dict(tree_flatten(self.parameters())))

    @staticmethod
    def is_module(value):
        """Return ``True`` if ``value`` is a :class:`Module` instance."""
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        """Keep containers (dicts, lists and therefore modules) when looking
        for child modules."""
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        """Keep arrays and containers whose key is not private (``_``-prefixed)."""
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        """Like :meth:`valid_parameter_filter` but also drop frozen keys."""
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )

    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns true.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Inside nested dicts and lists, entries rejected by ``filter_fn`` are
        kept as an empty dict (``{}``) so that keys and list positions still
        line up with the module's structure.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found and the containing module decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered

        Raises:
            RuntimeError: If a value that is neither a leaf nor a Module, dict
                or list is encountered (only possible with a custom
                ``is_leaf_fn``).
        """
        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                # Submodules apply the same functions relative to themselves.
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, child in v.items():
                    tk = f"{vk}.{k}"
                    nd[k] = unwrap(tk, child) if filter_fn(self, tk, child) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this Module
        as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non frozen :class:`mlx.core.array` members of
        this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the
        tracers in the model in order to compute gradients.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated: keys absent from ``parameters``, trailing positions missing
        from a shorter list, and empty-dict placeholders (which
        :meth:`filter_and_map` emits for filtered-out entries, e.g. frozen
        parameters nested in a dict or list) leave the current values as is.

        Args:
            parameters (dict): A complete or partial dictionary of the modules
                parameters.
        """

        def is_placeholder(value):
            # ``{}`` marks an entry that a filter rejected, not a new value.
            return isinstance(value, dict) and not value

        def set_or_recurse(dst, key, new_value):
            current_value = dst[key]
            if isinstance(current_value, mx.array):
                # Never overwrite a real array with a placeholder slot.
                if not is_placeholder(new_value):
                    dst[key] = new_value
            elif isinstance(current_value, Module):
                current_value.update(new_value)
            elif isinstance(current_value, (dict, list)):
                apply(current_value, new_value)

        def apply(dst, parameters):
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        set_or_recurse(dst, k, parameters[k])
            elif isinstance(parameters, list):
                # Walk only the positions present in both lists so a partial
                # (shorter) list updates a prefix instead of raising IndexError.
                for i in range(min(len(dst), len(parameters))):
                    set_or_recurse(dst, i, parameters[i])

        apply(self, parameters)

    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16 bit floats.

        Args:
            map_fn (Callable): Maps an array to another array
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def update_modules(self, modules: dict):
        """Replace the child modules of this :class:`Module` instance with the
        provided ones in the dict of dicts and lists.

        It is the equivalent of :meth:`Module.update` but for modules instead
        of parameters and allows us to flexibly edit complex architectures by
        programmatically swapping layers.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated; a list shorter than the corresponding module list replaces
        only its leading positions.

        Args:
            modules (dict): A complete or partial dictionary of the modules
                submodules.
        """

        def replace_or_recurse(dst, key, new_value):
            current_value = dst[key]
            if self.is_module(current_value) and self.is_module(new_value):
                dst[key] = new_value
            elif isinstance(current_value, (dict, list)):
                apply(current_value, new_value)

        def apply(dst, modules):
            if isinstance(modules, dict):
                for k in modules:
                    if k in dst:
                        replace_or_recurse(dst, k, modules[k])
            elif isinstance(modules, list):
                # Positions present in both lists only; avoids IndexError on
                # a partial (shorter) list.
                for i in range(min(len(dst), len(modules))):
                    replace_or_recurse(dst, i, modules[i])

        apply(self, modules)

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        ``apply_fn`` receives the module's dot-notation name (``""`` for this
        instance) and the module itself. Traversal is depth-first using an
        explicit stack.

        Args:
            apply_fn (Callable): The function to apply to the modules.
        """
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            # NOTE(review): the extra leading "." is assumed to be consumed by
            # how ``tree_flatten`` joins prefixes; confirm against mlx.utils.
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )

    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        """Normalize ``keys`` (a single name or an iterable of names) to a list.

        Raises:
            KeyError: If ``strict`` is ``True`` and a key is not a direct
                member of this module.
        """
        keys = [keys] if isinstance(keys, str) else list(keys)
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys

    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter means not
        computing gradients for it.

        This function is idempotent i.e. freezing a frozen model is a no-op.

        Example:
            For instance to only train the attention parameters from a Transformer:

            .. code-block:: python

                model = nn.Transformer()
                model.freeze()
                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If True then freeze the parameters of the
                submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen otherwise all the parameters of a
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
                Default: ``False``.
        """

        def _freeze_impl(_, m):
            if keys is None:
                # All arrays owned by ``m`` itself (nested dicts/lists yield
                # dotted names such as "d.a"); submodule values are excluded
                # by the filter. These names are derived from ``m`` and so are
                # not validated -- dotted names would fail ``k in m``.
                owned = m.filter_and_map(
                    lambda mod, k, v: (not isinstance(v, Module))
                    and mod.valid_parameter_filter(mod, k, v)
                )
                local_keys = [name for name, _value in tree_flatten(owned)]
            else:
                local_keys = m._validate_keys(keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)

    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent i.e. unfreezing a model that is not frozen
        is a no-op.

        Example:
            For instance to only train the biases of a Transformer one can do:

            .. code-block:: python

                model = nn.Transformer()
                model.freeze()
                model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If True then unfreeze the parameters of the
                submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen otherwise all the parameters of a
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
                Default: ``False``.
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                m._no_grad.clear()
            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        """Set the model in or out of training mode.

        Training mode only applies to certain layers. For example
        :obj:`Dropout` applies a random mask in training mode, but is the
        identity in evaluation mode.

        Args:
            mode (bool): Indicate if the model should be in training or
                evaluation mode. Default: ``True``.
        """

        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        """Set the model to evaluation mode.

        See :func:`train`.
        """
        self.train(False)
|