Training utilities
ONNX
onnxcustom.utils.onnx_helper.add_initializer(model, name, value)
Adds an initializer to the graph.
onnxcustom.utils.onnx_helper.dtype_to_var_type(dtype)
Converts a numpy dtype into a var type.
onnxcustom.utils.onnx_helper.get_onnx_opset(onx, domain='')
Returns the opset version the model imports for the given domain ('' being the main ONNX domain).
onnxcustom.utils.orttraining_helper.get_train_initializer(onx)
Returns the list of initializers to train.
onnxcustom.utils.onnx_helper.proto_type_to_dtype(proto_type)
Converts an ONNX TensorProto type into a numpy type.
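A sketch of the conversion as a lookup table that only needs numpy. The integer keys are values of the `onnx.TensorProto.DataType` enum, and `proto_type_to_dtype_sketch` is a hypothetical name. Recent releases of onnx also provide `onnx.helper.tensor_dtype_to_np_dtype` for the same purpose.

```python
import numpy as np

# Integer codes of onnx.TensorProto.DataType for common element types,
# written out by hand so the sketch only depends on numpy.
PROTO_TO_NUMPY = {
    1: np.float32,   # FLOAT
    2: np.uint8,     # UINT8
    3: np.int8,      # INT8
    4: np.uint16,    # UINT16
    5: np.int16,     # INT16
    6: np.int32,     # INT32
    7: np.int64,     # INT64
    9: np.bool_,     # BOOL
    10: np.float16,  # FLOAT16
    11: np.float64,  # DOUBLE
}


def proto_type_to_dtype_sketch(proto_type):
    """Hypothetical helper: maps a TensorProto element type to a numpy dtype."""
    try:
        return np.dtype(PROTO_TO_NUMPY[proto_type])
    except KeyError:
        raise ValueError(f"Unsupported TensorProto type {proto_type}.") from None


print(proto_type_to_dtype_sketch(1))  # float32
```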
onnxcustom.utils.onnx_helper.onnx_rename_weights(onx)
Renames ONNX initializers so that their names follow alphabetical order. The model is modified in place. This function calls onnx_rename_names.
onnxcustom.utils.onnx_rewriter.onnx_rewrite_operator(onx, op_type, sub_onx, recursive=True, debug_info=None)
Replaces an operator with an ONNX graph.
onnxcustom.utils.onnx_helper.replace_initializers_into_onnx(model, results)
Replaces initializers with other initializers, usually trained ones.
onnxruntime
onnxcustom.utils.onnxruntime_helper.device_to_providers(device)
Returns the corresponding providers for a specific device.
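A sketch of such a mapping in plain Python, assuming devices are given as strings like `'cpu'` or `'cuda:0'`. The helper name and the accepted spellings are assumptions; the provider names are onnxruntime's own.

```python
def device_to_providers_sketch(device):
    """Hypothetical helper: maps a device string such as 'cpu' or 'cuda:0'
    to an ordered list of onnxruntime execution providers."""
    kind = device.split(":", 1)[0].lower()
    if kind == "cpu":
        return ["CPUExecutionProvider"]
    if kind in ("cuda", "gpu"):
        # Keep the CPU provider as a fallback for operators without a CUDA kernel.
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    raise ValueError(f"Unsupported device {device!r}.")


print(device_to_providers_sketch("cuda:0"))
```

The returned list can be passed as the `providers` argument of `onnxruntime.InferenceSession`.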
onnxcustom.utils.onnxruntime_helper.numpy_to_ort_value(arr, device=None)
Converts a numpy array into a C_OrtValue.
onnxcustom.utils.onnxruntime_helper.get_ort_device(device)
Converts a device into a C_OrtDevice.
onnxcustom.utils.onnxruntime_helper.get_ort_device_type(device)
Converts a device into a device type.
onnxcustom.utils.onnxruntime_helper.ort_device_to_string(device)
Returns a string representing the device. This is the inverse of get_ort_device.
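get_ort_device and ort_device_to_string work with onnxruntime's C_OrtDevice objects. The pure-Python sketch below only models the string side of that round trip, using a hypothetical `(device_type, device_id)` tuple in place of C_OrtDevice; both helper names are hypothetical.

```python
def parse_device_sketch(device):
    """Hypothetical helper: splits 'cpu' or 'cuda:1' into (device_type, device_id)."""
    kind, _, index = device.partition(":")
    kind = kind.lower()
    if kind not in ("cpu", "cuda"):
        raise ValueError(f"Unsupported device {device!r}.")
    # A missing index means device 0.
    return kind, (int(index) if index else 0)


def device_to_string_sketch(device):
    """Hypothetical inverse: turns (device_type, device_id) back into a string."""
    kind, index = device
    return kind if kind == "cpu" else f"{kind}:{index}"


print(parse_device_sketch("cuda:1"))  # ('cuda', 1)
```

With this convention, `'cuda'` and `'cuda:0'` parse to the same tuple, and the string produced on the way back is always the explicit form.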
onnxcustom.utils.onnxruntime_helper.provider_to_device(provider_name)
Converts a provider name into a device.
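A sketch of the reverse lookup, covering only the CPU and CUDA providers. The helper name and the returned device strings are assumptions.

```python
def provider_to_device_sketch(provider_name):
    """Hypothetical helper: maps an onnxruntime execution provider name
    to a device string."""
    mapping = {"CPUExecutionProvider": "cpu", "CUDAExecutionProvider": "cuda"}
    try:
        return mapping[provider_name]
    except KeyError:
        raise ValueError(f"Unsupported provider {provider_name!r}.") from None


print(provider_to_device_sketch("CUDAExecutionProvider"))  # cuda
```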
functions
onnxcustom.utils.orttraining_helper.add_loss_output(onx, score_name='squared_error', loss_name='loss', label_name='label', weight_name=None, penalty=None, output_index=None, **kwargs)
Modifies an ONNX graph to add operators that compute a score (the loss) and make the graph trainable.
onnxcustom.utils.onnx_function.get_supported_functions()
Returns the list of functions supported by function_onnx_graph.
onnxcustom.utils.onnx_function.function_onnx_graph(name, target_opset=None, dtype=numpy.float32, weight_name=None, **kwargs)
Returns the ONNX graph corresponding to a function.
onnxcustom.utils.orttraining_helper.penalty_loss_onnx(name, dtype, l1=None, l2=None, existing_names=None)
Returns ONNX nodes to compute $|w|\alpha + w^2\beta$ where $\alpha = l1$ and $\beta = l2$.
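A numpy reference for an L1/L2 penalty, assuming it is `|w| * l1 + w**2 * l2` summed over every element of every weight tensor, with a coefficient left to `None` contributing nothing. `penalty_sketch` is hypothetical. Comparing its result with the output of the generated ONNX nodes is one way to check them.

```python
import numpy as np


def penalty_sketch(weights, l1=None, l2=None):
    """Hypothetical reference: sum over all weights of |w| * l1 + w**2 * l2.
    A coefficient left to None contributes nothing."""
    total = 0.0
    for w in weights:
        w = np.asarray(w, dtype=np.float64)
        if l1 is not None:
            total += l1 * np.abs(w).sum()
        if l2 is not None:
            total += l2 * np.square(w).sum()
    return float(total)


weights = [np.array([1.0, -2.0]), np.array([[3.0]])]
print(penalty_sketch(weights, l1=0.1, l2=0.01))
```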
gradient
onnxcustom.training.grad_helper.onnx_derivative(onx, weights=None, inputs=None, options=DerivativeOptions.Zero, loss=None, label=None, path_name=None)
Builds the gradient of an ONNX graph.
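Whatever graph onnx_derivative produces, its gradient values can be sanity-checked against central finite differences. The sketch below does this in numpy for the squared-error loss of a linear model; the model, the loss and all function names are assumptions chosen for illustration and are not part of onnxcustom.

```python
import numpy as np


def loss(w, X, y):
    # Sum of squared residuals of a linear model without intercept.
    r = X @ w - y
    return float(r @ r)


def analytic_grad(w, X, y):
    # d/dw sum((Xw - y)^2) = 2 X^T (Xw - y)
    return 2.0 * X.T @ (X @ w - y)


def finite_difference_grad(f, w, eps=1e-6):
    # Central differences, one coordinate at a time.
    grad = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = eps
        grad[i] = (f(w + step) - f(w - step)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.normal(size=20)
w = rng.normal(size=3)
g_exact = analytic_grad(w, X, y)
g_approx = finite_difference_grad(lambda v: loss(v, X, y), w)
print(np.max(np.abs(g_exact - g_approx)))
```

In practice, `g_exact` would be replaced by the values returned when running the gradient graph.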