import math
import warnings
from typing import Union, Tuple, Optional

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutput
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module
from transformers.models.hubert.modeling_hubert import (
    HubertFeatureEncoder,
    HubertFeatureProjection,
    HubertEncoderStableLayerNorm,
    HubertEncoder,
    _HIDDEN_STATES_START_POSITION,
)

from .configuration_hubert_spkreg import HubertSpkRegConfig


class HubertSpkRegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HubertSpkRegConfig
    base_model_prefix = "hubert"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths
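
    # Worked example (a sketch, assuming the default HuBERT feature-encoder
    # configuration conv_kernel=(10, 3, 3, 3, 3, 2, 2) and
    # conv_stride=(5, 2, 2, 2, 2, 2, 2)): one second of 16 kHz audio gives
    # 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49 frames, i.e.
    # roughly one feature vector every 20 ms:
    # >>> model._get_feat_extract_output_lengths(torch.tensor(16000))  # `model` is any instance
    # tensor(49)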

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations make sure that all values before the output lengths indices are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask
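
    # Worked example (a sketch): for feature_vector_length=5 and an output
    # length of 2, the row starts as [0, 1, 0, 0, 0]; flip gives
    # [0, 0, 0, 1, 0], cumsum gives [0, 0, 0, 1, 1], and the final flip gives
    # [1, 1, 0, 0, 0], i.e. exactly the first 2 feature frames are attended.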


def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
            the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
            independently generated mask spans of length `mask_length` is computed by
            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
            actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
            each batch dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked_span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick the first sampled index as a dummy index used to pad the vector
        # so that all rows have the same number of spans despite probabilistic
        # rounding; padding with an already-chosen index masks nothing new
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length`, in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offsets to the starting indices so that each index becomes a full span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
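
# Example (a sketch, assuming HuBERT's default time-masking settings
# mask_time_prob=0.05, mask_time_length=10, mask_time_min_masks=2): for a
# 49-frame sequence, int(0.05 * 49 / 10 + epsilon) rounds to 0 or 1, so
# `min_masks` lifts the count to 2 spans, masking at most 2 * 10 of the 49
# frames per example (overlapping spans can reduce the effective total):
# >>> mask = _compute_mask_indices((3, 49), mask_prob=0.05, mask_length=10, min_masks=2)
# >>> mask.shape
# (3, 49)
# >>> mask.sum(-1)  # per-row count of masked frames, at most 2 * 10 = 20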


class HubertSpkRegModel(HubertSpkRegPreTrainedModel):

    def __init__(self, config: HubertSpkRegConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = HubertFeatureProjection(config)

        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        """
        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states
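
    # Example (a sketch, assuming the default masking hyperparameters
    # mask_time_prob=0.05, mask_time_length=10, mask_time_min_masks=2): during
    # training, a batch of projected features of shape (4, 49, 768) gets 2 time
    # spans of 10 frames (possibly overlapping) replaced by `masked_spec_embed`,
    # and, if `mask_feature_prob` > 0, the selected hidden channels are zeroed
    # at every time step via the `[:, None].expand(-1, sequence_length, -1)`
    # broadcast above.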

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """

        Returns:

        Example:

        ```python
        >>> from transformers import AutoProcessor, HubertModel
        >>> from datasets import load_dataset
        >>> import soundfile as sf

        >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")


        >>> def map_to_array(batch):
        ...     speech, _ = sf.read(batch["file"])
        ...     batch["speech"] = speech
        ...     return batch


        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class AngularLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.FloatTensor(out_features, in_features), requires_grad=True
        )
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        # cosine similarity between L2-normalized embeddings and class weights
        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
        return cosine

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )
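
# AngularLinear usage sketch (illustrative shapes and names, not part of the
# original file): the layer returns cosine similarities in [-1, 1] between
# utterance embeddings and per-speaker weight vectors, which AMSoftmaxLoss /
# AAMSoftmaxLoss below then scale and margin-shift:
# >>> head = AngularLinear(in_features=256, out_features=1251)
# >>> cosines = head(torch.randn(8, 256))
# >>> cosines.shape, float(cosines.abs().max()) <= 1.0
# (torch.Size([8, 1251]), True)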


class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor for logits (default: 30.0)
            margin: Additive cosine margin (default: 0.35)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method for the cross-entropy loss (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities (e.g. from `AngularLinear`) of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size,)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape
        # subtract the margin from the target-class cosine only
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        psi = cos_theta - self.margin
        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss
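
# Worked example (a sketch): with scale=30 and margin=0.35, a target-class
# cosine of 0.80 becomes 30 * (0.80 - 0.35) = 13.5, while a competing class at
# cosine 0.50 stays at 30 * 0.50 = 15.0, so the target is only favored once its
# cosine exceeds every other class by more than the margin:
# >>> loss_fct = AMSoftmaxLoss(scale=30.0, margin=0.35)
# >>> cosines = torch.tensor([[0.80, 0.50, -0.20]])
# >>> loss_fct(cosines, torch.tensor([0]))  # still penalized although argmax == 0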


class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "ArcFace: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.2,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor for logits (default: 30.0)
            margin: Additive angular margin in radians (default: 0.2)
            easy_margin: Use the easy margin variant (default: False)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method for the cross-entropy loss (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities (e.g. from `AngularLinear`) of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size,)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape

        epsilon = 1e-6

        # recover cos(theta) and sin(theta) from the cosine logits
        cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)

        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        psi = cos_theta * cos_m - sin_theta * sin_m

        if self.easy_margin:
            psi = torch.where(cos_theta > 0, psi, cos_theta)
        else:
            # once theta + m exceeds pi, cos(theta + m) is no longer monotonic,
            # so fall back to a subtractive penalty on the cosine
            psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)

        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss
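
# Worked example (a sketch): with margin=0.2 rad (about 11.5 degrees), a target
# whose embedding sits at theta = 60 degrees from its class weight gets
# cos(60deg + 11.5deg) ~= 0.32 instead of cos(60deg) = 0.50 before scaling,
# while non-target classes keep their plain cosines:
# >>> loss_fct = AAMSoftmaxLoss(scale=30.0, margin=0.2)
# >>> cosines = torch.tensor([[0.50, 0.30]])
# >>> loss_fct(cosines, torch.tensor([0]))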


class HubertSpkRegForSequenceClassification(HubertSpkRegPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
            )
        self.hubert = HubertSpkRegModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)

        if self.config.loss_fct == 'cross_entropy':
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct == 'additive_angular_margin':
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        self.hubert.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.hubert.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss (e.g. speaker identities). Indices should be in
            `[0, ..., config.num_labels - 1]`. The loss is selected by `config.loss_fct` (cross-entropy, AM-softmax,
            or AAM-softmax).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # mean-pool only over non-padded frames
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.loss_fct == 'cross_entropy':
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_margin':
                loss_fct = AMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_angular_margin':
                loss_fct = AAMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    easy_margin=self.config.easy_margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
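

# Usage sketch (illustrative only; the config values below are assumptions,
# not part of this module, and assume the SpkReg config provides defaults for
# scale, margin, label_smoothing, and reduction):
# >>> from .configuration_hubert_spkreg import HubertSpkRegConfig
# >>> config = HubertSpkRegConfig(
# ...     num_labels=1251,                        # e.g. number of training speakers
# ...     loss_fct="additive_angular_margin",     # or "cross_entropy" / "additive_margin"
# ... )
# >>> model = HubertSpkRegForSequenceClassification(config)
# >>> input_values = torch.randn(2, 16000)        # two 1-second utterances at 16 kHz
# >>> labels = torch.tensor([3, 7])
# >>> outputs = model(input_values, labels=labels)
# >>> outputs.logits.shape                        # (batch_size, num_labels)
# torch.Size([2, 1251])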