# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import List, Optional

import torch

from ._debug import map_debug_info


def stage_backward(
    stage_output,
    output_grads,
    input_values,
    outputs_with_grads_idxs: Optional[List[int]] = None,  # deprecated, not used
):
    """
    This is a helper function to:
    1. compute the gradients for the stage inputs, and
    2. accumulate gradients for the stage module's parameters.

    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values (leaves
    in the autograd trace) as well as return a list of the gradients for the
    input values.
    """
    if outputs_with_grads_idxs is not None:
        # Deprecated, not used in runtime calls, only exists in compiler
        stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
        output_grads = [output_grads[i] for i in outputs_with_grads_idxs]

    try:
        # stage_output may be a composite datatype like dict. Extract all individual
        # tensor values here
        stage_output_tensors = []
        output_grad_tensors = []

        def extract_tensors_with_grads(output_val, grad_val):
            if isinstance(output_val, torch.Tensor):
                if not output_val.requires_grad and output_val.grad_fn is None:
                    return
                assert isinstance(
                    grad_val, (torch.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
            elif isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    return
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                for ov, gv in zip(output_val, grad_val):
                    extract_tensors_with_grads(ov, gv)
            elif isinstance(output_val, dict):
                if grad_val is None:
                    return
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                for k in output_val.keys():
                    extract_tensors_with_grads(output_val[k], grad_val[k])
            else:
                # Output is a non-tensor type; just ignore it
                pass

        extract_tensors_with_grads(stage_output, output_grads)

        torch.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
        )

        # Extract gradients wrt the input values
        grad_inputs = []
        for val in input_values:
            if isinstance(val, torch.Tensor):
                grad_inputs.append(val.grad)
            else:
                grad_inputs.append(None)

        # Alternative impl: `torch.autograd.grad`.
        # Note that `torch.autograd.grad` will not accumulate gradients into the
        # model's parameters.
        """
        inputs_with_grad = []
        for val in input_values:
            if isinstance(val, torch.Tensor) and val.requires_grad:
                inputs_with_grad.append(val)

        grad_inputs = torch.autograd.grad(
            stage_output_tensors, inputs_with_grad, output_grad_tensors,  # type: ignore[arg-type]
        )
        """

    except Exception as e:
        exc_msg = f"""
        Failed to run stage backward:
        Stage output: {map_debug_info(stage_output)}
        Output gradient: {map_debug_info(output_grads)}
        Input: {map_debug_info(input_values)}
        """
        raise RuntimeError(exc_msg) from e

    return grad_inputs


# TODO: handling requires_grad=False dynamically. Can we analyze this during initial
# IR emission?
def _null_coalesce_accumulate(lhs, rhs):
    """
    Coalesce two values, even if one of them is null, returning the non-null value.
    """
    if lhs is None:
        return rhs
    elif rhs is None:
        return lhs
    else:
        return torch.add(lhs, rhs)
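

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's runtime API; normally
# the pipeline runtime drives these calls). It shows the expected calling
# convention for `stage_backward`: `input_values` must be leaf tensors (e.g.
# received activations that were detached and re-marked with
# `requires_grad_()`) so that `.grad` is populated after the backward pass.
# The Linear module and tensor shapes below are arbitrary, assumed only for
# demonstration.
if __name__ == "__main__":
    _model = torch.nn.Linear(4, 4)
    _x = torch.randn(2, 4, requires_grad=True)  # leaf input to this "stage"
    _out = _model(_x)  # stage forward
    _grad_out = torch.ones_like(_out)  # gradient arriving from the next stage
    _dx = stage_backward(
        stage_output=(_out,),
        output_grads=(_grad_out,),
        input_values=[_x],
    )
    # _dx[0] holds the gradient w.r.t. _x; _model's parameters now have `.grad`
    # accumulated as a side effect of torch.autograd.backward.
    assert _dx[0] is not None and _model.weight.grad is not None
    # _null_coalesce_accumulate combines partial gradients, treating None as zero.
    assert _null_coalesce_accumulate(None, _dx[0]) is _dx[0]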