Gradient
C++ API
- class onnxruntime.capi._pybind_state.GradientGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientGraphBuilder, arg0: bytes, arg1: Set[str], arg2: Set[str], arg3: str) -> None
A utility for making a gradient graph that can be used to help train a model.
- class onnxruntime.capi._pybind_state.GradientNodeAttributeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeAttributeDefinition) -> None
Attribute definition for gradient graph nodes.
- class onnxruntime.capi._pybind_state.GradientNodeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition) -> None
Definition for gradient graph nodes.
- onnxruntime.capi._pybind_state.register_gradient_definition(arg0: str, arg1: List[onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition]) -> None
Python API
- onnxruntime.training.experimental.gradient_graph._gradient_graph_tools.export_gradient_graph(model: torch.nn.modules.module.Module, loss_fn: Callable[[Any, Any], Any], example_input: torch.Tensor, example_labels: torch.Tensor, gradient_graph_path: Union[pathlib.Path, str], opset_version=12) -> None
Build a gradient graph for model so that you can output gradients from an inference session when given a specific input and its corresponding labels.
- Parameters
model (torch.nn.Module) – A gradient graph will be built for this model.
loss_fn (Callable[[Any, Any], Any]) – A function to compute the loss given the model’s output and the example_labels. Predefined loss functions such as torch.nn.CrossEntropyLoss() will work, but you might not be able to load the graph in other environments such as an InferenceSession in ONNX Runtime Web; use a custom Python method instead.
example_input (torch.Tensor) – Example input that you would give your model for inference/prediction.
example_labels (torch.Tensor) – The expected labels for example_input. This could be the output of your model when given example_input, but it might differ if your loss function expects labels in a different format.
gradient_graph_path (Union[Path, str]) – The path to where you would like to save the gradient graph.
opset_version (int) – The ONNX opset version to export with (default 12). See torch.onnx.export.