Gradient

C++ API

class onnxruntime.capi._pybind_state.GradientGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientGraphBuilder, arg0: bytes, arg1: Set[str], arg2: Set[str], arg3: str) → None

A utility for making a gradient graph that can be used to help train a model.

class onnxruntime.capi._pybind_state.GradientNodeAttributeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeAttributeDefinition) → None

Attribute definition for gradient graph nodes.

class onnxruntime.capi._pybind_state.GradientNodeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition) → None

Definition for gradient graph nodes.

onnxruntime.capi._pybind_state.register_gradient_definition(arg0: str, arg1: List[onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition]) → None
onnxruntime.capi._pybind_state.register_aten_op_executor(arg0: str, arg1: str) → None
onnxruntime.capi._pybind_state.register_backward_runner(arg0: object) → None
onnxruntime.capi._pybind_state.register_forward_runner(arg0: object) → None

Python API

onnxruntime.training.experimental.gradient_graph._gradient_graph_tools.export_gradient_graph(model: torch.nn.modules.module.Module, loss_fn: Callable[[Any, Any], Any], example_input: torch.Tensor, example_labels: torch.Tensor, gradient_graph_path: Union[pathlib.Path, str], opset_version=12) → None

Build a gradient graph for model so that an inference session can output gradients for a specific input and its corresponding labels.

Parameters
  • model (torch.nn.Module) – A gradient graph will be built for this model.

  • loss_fn (Callable[[Any, Any], Any]) – A function to compute the loss given the model’s output and the example_labels. Predefined loss functions such as torch.nn.CrossEntropyLoss() will work, but you might not be able to load the graph in other environments such as an InferenceSession in ONNX Runtime Web; in that case, use a custom Python method instead.

  • example_input (torch.Tensor) – Example input that you would give your model for inference/prediction.

  • example_labels (torch.Tensor) – The expected labels for example_input. These could be the output of your model when given example_input, but they might be different if your loss function expects labels in another form.

  • gradient_graph_path (Union[Path, str]) – The path to where you would like to save the gradient graph.

  • opset_version (int) – See torch.onnx.export.
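The loss_fn description recommends a custom Python method over predefined loss modules. As a minimal sketch, the hypothetical cross_entropy_loss below rebuilds mean cross-entropy from log_softmax and gather and checks it against torch.nn.functional.cross_entropy. Its labels are class indices rather than tensors shaped like the model's output, which is the kind of case where example_labels differs from the model's own output. Whether each of these ops exports and differentiates cleanly with your ONNX Runtime version is an assumption to verify.

```python
import torch
import torch.nn.functional as F


def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical mean cross-entropy built from basic tensor ops.

    logits: (batch, num_classes) float tensor of unnormalized scores.
    labels: (batch,) int64 tensor of class indices.
    """
    log_probs = torch.log_softmax(logits, dim=1)
    # Select the log-probability of the correct class for each row.
    picked = torch.gather(log_probs, 1, labels.unsqueeze(1)).squeeze(1)
    return -picked.mean()


torch.manual_seed(0)
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])

custom = cross_entropy_loss(logits, labels)
reference = F.cross_entropy(logits, labels)  # default reduction is "mean"
print(torch.allclose(custom, reference))
```

A function like this can then be passed as loss_fn, with example_labels holding class indices.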