Spaces:
Sleeping
Sleeping
# Unsloth Zoo - Utilities for Unsloth | |
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. | |
# | |
# This program is free software: you can redistribute it and/or modify | |
# it under the terms of the GNU Lesser General Public License as published by | |
# the Free Software Foundation, either version 3 of the License, or | |
# (at your option) any later version. | |
# | |
# This program is distributed in the hope that it will be useful, | |
# but WITHOUT ANY WARRANTY; without even the implied warranty of | |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
# GNU Lesser General Public License for more details.
# | |
# You should have received a copy of the GNU Lesser General Public License | |
# along with this program. If not, see <https://www.gnu.org/licenses/>. | |
# Compiler option flags kept at module level.
# NOTE(review): presumably passed as `options=` to `torch.compile` elsewhere;
# nothing in the visible part of this file reads it - confirm against callers.
torch_compile_options = {
    'epilogue_fusion': True,
    'max_autotune': False,
    'shape_padding': True,
    'trace.enabled': False,
    'triton.cudagraphs': False,
}
from torch import Tensor | |
import torch | |
from torch.nn import functional as F | |
from transformers.models.mllama.modeling_mllama import (F, nn) | |
def forward(self, input: Tensor) -> Tensor:
    """Batch-normalize ``input`` and return the result cast to ``input.dtype``.

    ``self`` is expected to expose ``_check_input_dim``, ``momentum``,
    ``training``, ``track_running_stats``, ``num_batches_tracked``,
    ``running_mean``, ``running_var``, ``weight``, ``bias`` and ``eps``
    (presumably a ``torch.nn`` batch-norm module - confirm against the caller).

    Args:
        input: Tensor to normalize; validated by ``self._check_input_dim``.

    Returns:
        The output of ``F.batch_norm`` converted back to ``input.dtype``.

    Side effects:
        When training with tracked running stats and a non-``None``
        ``num_batches_tracked``, that counter is incremented in place.
    """
    self._check_input_dim(input)

    # exponential_average_factor is set to self.momentum (when it is
    # available) only so that it gets updated in the ONNX graph when this
    # node is exported to ONNX.
    if self.momentum is None:
        exponential_average_factor = 0.0
    else:
        exponential_average_factor = self.momentum

    if self.training and self.track_running_stats:
        # TODO: if statement only here to tell the jit to skip emitting this when it is None
        if self.num_batches_tracked is not None:  # type: ignore[has-type]
            self.num_batches_tracked.add_(1)  # type: ignore[has-type]
            if self.momentum is None:
                # No momentum configured: use a cumulative moving average.
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            # Otherwise the factor already equals self.momentum
            # (exponential moving average), as set above.

    # Decide whether the mini-batch stats should be used for normalization
    # rather than the buffers. Mini-batch stats are used in training mode,
    # and in eval mode when both buffers are None.
    if self.training:
        bn_training = True
    else:
        bn_training = (self.running_mean is None) and (self.running_var is None)

    # Buffers are only updated if they are to be tracked and we are in
    # training mode. Thus they only need to be passed when the update should
    # occur (training mode with tracking enabled), or when buffer stats are
    # used for normalization (eval mode with non-None buffers). Passing None
    # otherwise ensures untracked buffers are never updated.
    pass_buffers = (not self.training) or self.track_running_stats
    running_mean = self.running_mean if pass_buffers else None
    running_var = self.running_var if pass_buffers else None

    normalized = F.batch_norm(
        input,
        running_mean,
        running_var,
        self.weight,
        self.bias,
        bn_training,
        exponential_average_factor,
        self.eps,
    )
    # Cast back so the caller gets the dtype it passed in.
    return normalized.to(input.dtype)