Partial Training
OrtValueCache
- class onnxruntime.capi._pybind_state.OrtValueCache(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
- count(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> int
- insert(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValue) -> None
- remove(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> None
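OrtValueCache exposes only count, insert and remove, all keyed by a string. The sketch below assumes the class behaves like a string-keyed map from names to OrtValue objects, so count returns 0 or 1. `SketchOrtValueCache` is a hypothetical pure-Python stand-in that holds arbitrary objects instead of real OrtValue instances. What happens on a duplicate insert is not documented here; the sketch keeps the first value, as a C++ map `emplace` would.

```python
from typing import Any, Dict


class SketchOrtValueCache:
    """Hypothetical stand-in for OrtValueCache (assumed: a name -> value map)."""

    def __init__(self) -> None:
        self._values: Dict[str, Any] = {}

    def count(self, name: str) -> int:
        # 1 if a value is cached under `name`, otherwise 0.
        return 1 if name in self._values else 0

    def insert(self, name: str, value: Any) -> None:
        # Assumption: an existing entry is kept, like map emplace.
        self._values.setdefault(name, value)

    def remove(self, name: str) -> None:
        # Assumption: removing a missing name is a no-op, like map erase.
        self._values.pop(name, None)


cache = SketchOrtValueCache()
cache.insert("some_node_arg", "tensor-placeholder")  # hypothetical name/value
print(cache.count("some_node_arg"))  # 1
cache.remove("some_node_arg")
print(cache.count("some_node_arg"))  # 0
```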
TrainingAgent
- class onnxruntime.capi._pybind_state.TrainingAgent(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice]) -> None
This is the main class used to run an ORTModule model.
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice]) -> None
- run_backward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None
- run_forward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState, arg3: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
PartialGraphExecutionState
- class onnxruntime.capi._pybind_state.PartialGraphExecutionState(self: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None
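TrainingAgent.run_forward and run_backward both take the same PartialGraphExecutionState, which suggests that the forward call leaves intermediate results in that object for the backward call to resume from. The toy below models that split for y = w * x in plain Python. `SketchState` and `SketchAgent` are hypothetical. The reading of the OrtValueVector arguments as inputs and in-place outputs, and the idea that the state carries saved activations, are assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SketchState:
    """Hypothetical stand-in for PartialGraphExecutionState."""
    saved: Dict[str, List[float]] = field(default_factory=dict)


class SketchAgent:
    """Hypothetical toy agent for y = w * x, split into two partial runs."""

    def __init__(self, w: float) -> None:
        self.w = w

    def run_forward(self, x: List[float], outputs: List[float], state: SketchState) -> None:
        # Save what the backward pass needs, then write outputs in place.
        state.saved["x"] = list(x)
        outputs[:] = [self.w * xi for xi in x]

    def run_backward(self, grad_y: List[float], grads: List[float], state: SketchState) -> None:
        # Resume from the saved state. Output layout: [dL/dw, dL/dx_0, dL/dx_1, ...].
        x = state.saved.pop("x")
        grad_w = sum(g * xi for g, xi in zip(grad_y, x))
        grads[:] = [grad_w] + [g * self.w for g in grad_y]


agent = SketchAgent(w=3.0)
state = SketchState()  # one state object per forward/backward pair
y: List[float] = []
agent.run_forward([1.0, 2.0], y, state)
grads: List[float] = []
agent.run_backward([1.0, 1.0], grads, state)
print(y, grads)  # [3.0, 6.0] [3.0, 3.0, 3.0]
```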
OrtModuleGraphBuilder
- class onnxruntime.capi._pybind_state.OrtModuleGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None
- build(*args, **kwargs)
Overloaded function.
1. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None
2. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: List[List[int]]) -> None
- get_graph_info(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> onnxruntime.capi.onnxruntime_pybind11_state.GraphInfo
- get_inference_optimized_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
- get_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
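The second build overload takes a List[List[int]]. Assuming it means one shape per graph input, in the model's input order, a caller holding input tensors could derive it as follows. `input_shapes` is a hypothetical helper. Any object with a `.shape` sequence of integers, such as a NumPy array or a PyTorch tensor, works; `SimpleNamespace` objects stand in for tensors so the sketch runs on its own.

```python
from types import SimpleNamespace
from typing import List, Sequence


def input_shapes(tensors: Sequence[object]) -> List[List[int]]:
    """Hypothetical helper: one list of dimension sizes per input tensor.

    Assumes each tensor exposes an iterable `.shape` of ints and that the
    tensors are passed in the same order as the graph inputs.
    """
    return [[int(d) for d in t.shape] for t in tensors]


batch = [SimpleNamespace(shape=(8, 128)), SimpleNamespace(shape=(8,))]
shapes = input_shapes(batch)
print(shapes)  # [[8, 128], [8]]
# The result matches the List[List[int]] type of OrtModuleGraphBuilder.build(arg0).
```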
OrtModuleGraphBuilderConfiguration
- class onnxruntime.capi._pybind_state.OrtModuleGraphBuilderConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None
Configuration for the module graph builder.
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None
- property build_gradient_graph
- property enable_caching
- property graph_transformer_config
- property initializer_names
- property initializer_names_to_train
- property input_names_require_grad
- property loglevel
- property use_memory_efficient_gradient
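OrtModuleGraphBuilderConfiguration is set up by assigning its properties. The sketch below assumes, as the property names suggest, that initializer_names lists every parameter, that initializer_names_to_train is the subset needing gradients, and that build_gradient_graph toggles gradient-graph construction. `builder_settings` and `apply_settings` are hypothetical helpers. A `types.SimpleNamespace` stands in for the real configuration object so the code runs without onnxruntime.

```python
from types import SimpleNamespace
from typing import Dict, List, Mapping


def builder_settings(params: Mapping[str, bool],
                     inputs_requiring_grad: List[str],
                     training: bool) -> Dict[str, object]:
    """Hypothetical helper: {param name: requires_grad} -> property values."""
    return {
        "initializer_names": list(params),
        "initializer_names_to_train": [n for n, req in params.items() if req],
        "input_names_require_grad": list(inputs_requiring_grad),
        "build_gradient_graph": training,
    }


def apply_settings(config: object, settings: Mapping[str, object]) -> None:
    # Assign each entry as an attribute, as one would on the pybind object.
    for name, value in settings.items():
        setattr(config, name, value)


config = SimpleNamespace()  # stand-in for OrtModuleGraphBuilderConfiguration()
params = {"fc.weight": True, "fc.bias": True, "embed.weight": False}
apply_settings(config, builder_settings(params, ["x"], training=True))
print(config.initializer_names_to_train)  # ['fc.weight', 'fc.bias']
```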
TrainingGraphTransformerConfiguration
- class onnxruntime.capi._pybind_state.TrainingGraphTransformerConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration) -> None
Training graph transformer configuration.
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration) -> None
- property allow_layer_norm_mod_precision
- property attn_dropout_recompute
- property enable_gelu_approximation
- property gelu_recompute
- property number_recompute_layers
- property propagate_cast_ops_config
- property transformer_layer_recompute
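Because the transformer options are plain properties, a misspelled name is an easy mistake when they are set from a dictionary. The sketch below checks keys against the property names listed for TrainingGraphTransformerConfiguration before assigning them. `checked_apply` is a hypothetical helper, and the value types used (booleans and an integer layer count) are assumptions. The example does not set propagate_cast_ops_config because its type is not documented here. A `SimpleNamespace` stands in for the real object.

```python
from types import SimpleNamespace
from typing import Mapping

# Property names as listed for TrainingGraphTransformerConfiguration.
TRANSFORMER_PROPERTIES = frozenset({
    "allow_layer_norm_mod_precision",
    "attn_dropout_recompute",
    "enable_gelu_approximation",
    "gelu_recompute",
    "number_recompute_layers",
    "propagate_cast_ops_config",
    "transformer_layer_recompute",
})


def checked_apply(config: object, settings: Mapping[str, object]) -> None:
    """Hypothetical helper: reject unknown names, then setattr each value."""
    unknown = sorted(set(settings) - TRANSFORMER_PROPERTIES)
    if unknown:
        # Nothing is assigned if any name is unknown.
        raise ValueError(f"unknown transformer options: {unknown}")
    for name, value in settings.items():
        setattr(config, name, value)


config = SimpleNamespace()  # stand-in for TrainingGraphTransformerConfiguration()
checked_apply(config, {"gelu_recompute": True, "number_recompute_layers": 2})
try:
    checked_apply(config, {"gelu_recompte": True})  # the typo is caught
except ValueError as err:
    print(err)
```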