Partial Training

OrtValueCache

class onnxruntime.capi._pybind_state.OrtValueCache(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
clear(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
count(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> int
insert(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValue) -> None
keys(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> list
remove(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> None

TrainingAgent

class onnxruntime.capi._pybind_state.TrainingAgent(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice]) -> None

This is the main class used to run an ORTModule model.

__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice]) -> None
run_backward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None
run_forward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState, arg3: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None

PartialGraphExecutionState

class onnxruntime.capi._pybind_state.PartialGraphExecutionState(self: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None
__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None

OrtModuleGraphBuilder

class onnxruntime.capi._pybind_state.OrtModuleGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None
__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None
build(*args, **kwargs)

Overloaded function.

  1. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None

  2. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: List[List[int]]) -> None

get_graph_info(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> onnxruntime.capi.onnxruntime_pybind11_state.GraphInfo
get_inference_optimized_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
get_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
initialize(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: bytes, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None

OrtModuleGraphBuilderConfiguration

class onnxruntime.capi._pybind_state.OrtModuleGraphBuilderConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None

Configuration information for module graph builder.

__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None
property build_gradient_graph
property enable_caching
property graph_transformer_config
property initializer_names
property initializer_names_to_train
property input_names_require_grad
property loglevel
property use_memory_efficient_gradient

TrainingGraphTransformerConfiguration

class onnxruntime.capi._pybind_state.TrainingGraphTransformerConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration) -> None

Training Graph transformer configuration.

__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration) -> None
property allow_layer_norm_mod_precision
property attn_dropout_recompute
property enable_gelu_approximation
property gelu_recompute
property number_recompute_layers
property propagate_cast_ops_config
property transformer_layer_recompute