Spaces:
Running
Running
from typing import Any, Dict, Union | |
import torch | |
from packaging import version | |
from torch import nn | |
from transformers import ( | |
Trainer, | |
is_apex_available, | |
) | |
# Optional mixed-precision backends. Each name is only bound when its backend
# is usable, so code that uses `amp` or `autocast` must be gated accordingly.
if is_apex_available():
    from apex import amp

# Default to False so the flag is always defined, even on torch < 1.6 where
# the native AMP import below is skipped.
_is_native_amp_available = False
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast
class CTCTrainer(Trainer):
    """Trainer subclass that performs the forward/backward pass of a step by hand."""

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run one forward/backward pass on a batch and return the detached loss.

        Args:
            model: Module to train; it is switched to training mode here.
            inputs: Batch that is passed through ``self._prepare_inputs`` and
                then handed to ``self.compute_loss``.

        Returns:
            The detached loss, divided by ``args.gradient_accumulation_steps``
            when that value is greater than 1.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Run the forward pass under autocast only when native AMP is enabled.
        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # Assumes n_gpu > 1 means the model is wrapped for multi-GPU data
            # parallelism and yields one loss per replica (as in the stock
            # Trainer); backward() requires a scalar, so average them.
            loss = loss.mean()

        # Keep the loss before accumulation scaling: DeepSpeed's backward()
        # divides by the accumulation steps itself, so passing it the already
        # divided loss would scale the gradients down twice.
        unscaled_loss = loss
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(unscaled_loss)
        else:
            loss.backward()

        # The returned value is scaled the same way for every backend.
        return loss.detach()