Training utilities
ONNX
onnxcustom.utils.onnx_helper.add_initializer (model, name, value)
Adds an initializer to the model's graph.
onnxcustom.utils.onnx_helper.dtype_to_var_type (dtype)
Converts a numpy dtype into a var type.
onnxcustom.utils.onnx_helper.get_onnx_opset (onx, domain = '')
Returns the opset version the model declares for a domain (the default '' designates the main ONNX domain).
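A minimal sketch of the lookup, using lightweight stand-ins for the ONNX protobuf classes: in ONNX, `ModelProto.opset_import` is a list of entries carrying `domain` and `version` fields, which the hypothetical `OpsetId` and `FakeModel` classes mimic here. The sketch returns `None` when the domain is not declared; how the library itself handles that case is not stated here.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class OpsetId:
    # Hypothetical stand-in for onnx's OperatorSetIdProto (domain + version).
    domain: str
    version: int


@dataclass
class FakeModel:
    # Hypothetical stand-in for ModelProto; only opset_import is modelled.
    opset_import: List[OpsetId]


def get_onnx_opset_sketch(model: FakeModel, domain: str = "") -> Optional[int]:
    """Return the opset version declared for `domain`, or None if absent."""
    for opset in model.opset_import:
        if opset.domain == domain:
            return opset.version
    return None


model = FakeModel([OpsetId("", 15), OpsetId("ai.onnx.ml", 2)])
print(get_onnx_opset_sketch(model))                # 15
print(get_onnx_opset_sketch(model, "ai.onnx.ml"))  # 2
```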
onnxcustom.utils.orttraining_helper.get_train_initializer (onx)
Returns the list of initializers to train.
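A sketch of how such a selection could work, assuming that only floating-point initializers are trainable and that integer ones (shapes, axes, indices) must be left untouched. The function name and the dict-of-arrays input are hypothetical; the real function operates on an ONNX model and its return type may differ.

```python
import numpy as np


def get_train_initializer_sketch(initializers):
    """Return the names of initializers assumed to be trainable.

    `initializers` maps names to numpy arrays. ASSUMPTION: only
    floating-point tensors are trained; integer tensors are skipped.
    """
    return [name for name, value in initializers.items()
            if np.issubdtype(value.dtype, np.floating)]


inits = {
    "coef": np.array([[0.5], [1.5]], dtype=np.float32),
    "intercept": np.array([0.1], dtype=np.float32),
    "shape": np.array([-1, 1], dtype=np.int64),  # e.g. a Reshape target
}
print(get_train_initializer_sketch(inits))  # ['coef', 'intercept']
```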
onnxcustom.utils.onnx_helper.proto_type_to_dtype (proto_type)
Converts an ONNX TensorProto type into a numpy type.
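A sketch of such a conversion as a lookup table. The integer codes are the `TensorProto.DataType` values defined by the ONNX specification; only a subset is mapped, and `proto_type_to_dtype_sketch` is a hypothetical name. Recent `onnx` releases also provide a comparable mapping through `onnx.helper.tensor_dtype_to_np_dtype`.

```python
import numpy as np

# TensorProto.DataType codes from the ONNX specification (subset only).
# STRING (8), UINT32 (12), UINT64 (13) and others are not mapped here.
_PROTO_TO_NUMPY = {
    1: np.float32,   # FLOAT
    2: np.uint8,     # UINT8
    3: np.int8,      # INT8
    4: np.uint16,    # UINT16
    5: np.int16,     # INT16
    6: np.int32,     # INT32
    7: np.int64,     # INT64
    9: np.bool_,     # BOOL
    10: np.float16,  # FLOAT16
    11: np.float64,  # DOUBLE
}


def proto_type_to_dtype_sketch(proto_type: int) -> np.dtype:
    """Map an ONNX TensorProto element type code to a numpy dtype."""
    try:
        return np.dtype(_PROTO_TO_NUMPY[proto_type])
    except KeyError:
        raise ValueError(f"Unsupported TensorProto type {proto_type}.") from None


print(proto_type_to_dtype_sketch(1))   # float32
print(proto_type_to_dtype_sketch(11))  # float64
```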
onnxcustom.utils.onnx_helper.onnx_rename_weights (onx)
Renames ONNX initializers so that their names follow alphabetical order. The model is modified in place. This function calls onnx_rename_names.
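One way to make names sort alphabetically in their original order is to prefix each with a zero-padded position index. The `I<index>_<name>` scheme below is a hypothetical illustration, not necessarily the naming onnx_rename_weights uses. The sketch only builds the old-to-new mapping; in a real graph, every node input referring to a renamed initializer must be updated as well.

```python
def rename_in_order(names):
    """Build an old-name -> new-name mapping whose new names sort
    alphabetically in the same order as `names`.

    HYPOTHETICAL scheme: prefix each name with a zero-padded index.
    """
    width = len(str(max(len(names) - 1, 0)))
    return {name: f"I{i:0{width}d}_{name}" for i, name in enumerate(names)}


mapping = rename_in_order(["coef", "bias", "intercept"])
print(mapping)
# {'coef': 'I0_coef', 'bias': 'I1_bias', 'intercept': 'I2_intercept'}
```

The zero padding matters: without it, `I10_w` would sort before `I2_w`.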
onnxcustom.utils.onnx_rewriter.onnx_rewrite_operator (onx, op_type, sub_onx, recursive = True, debug_info = None)
Replaces one operator with an ONNX graph.
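A simplified sketch of operator rewriting on a hypothetical list-of-nodes representation, where `Node` stands in for an ONNX `NodeProto`. Each matching node is replaced by a copy of the replacement nodes: the replacement's formal inputs and outputs are mapped to the replaced node's actual names, and intermediate results receive a prefix unique to that node to avoid name collisions. Subgraphs, initializers and the `recursive` option are ignored here.

```python
from typing import Dict, List, NamedTuple, Sequence, Tuple


class Node(NamedTuple):
    # Hypothetical stand-in for an ONNX NodeProto.
    op_type: str
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


def _rename(name: str, mapping: Dict[str, str], prefix: str) -> str:
    # Formal inputs/outputs map to actual names; anything else is an
    # intermediate result and gets a prefix unique to the replaced node.
    return mapping.get(name, f"{prefix}_{name}")


def rewrite_operator(nodes: List[Node], op_type: str, sub_nodes: List[Node],
                     sub_inputs: Sequence[str],
                     sub_outputs: Sequence[str]) -> List[Node]:
    """Replace every node of type `op_type` by a renamed copy of `sub_nodes`."""
    result = []
    for k, node in enumerate(nodes):
        if node.op_type != op_type:
            result.append(node)
            continue
        mapping = dict(zip(sub_inputs, node.inputs))
        mapping.update(zip(sub_outputs, node.outputs))
        prefix = f"{op_type}_{k}"
        for sub in sub_nodes:
            result.append(Node(
                sub.op_type,
                tuple(_rename(i, mapping, prefix) for i in sub.inputs),
                tuple(_rename(o, mapping, prefix) for o in sub.outputs),
            ))
    return result


# Rewrite Sub(A, B) as Add(A, Neg(B)).
graph = [Node("Sub", ("x", "y"), ("d",)), Node("Abs", ("d",), ("z",))]
replacement = [Node("Neg", ("B",), ("negB",)),
               Node("Add", ("A", "negB"), ("Y",))]
for n in rewrite_operator(graph, "Sub", replacement, ("A", "B"), ("Y",)):
    print(n)
```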
onnxcustom.utils.onnx_helper.replace_initializers_into_onnx (model, results)
Replaces initializers with other initializers, usually trained ones.
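A sketch of the replacement logic on plain dictionaries mapping names to numpy arrays, a hypothetical representation since the real function takes an ONNX model. Rejecting unknown names and shape changes, and casting new values back to the original dtype, are assumptions of this sketch.

```python
import numpy as np


def replace_initializers_sketch(initializers, results):
    """Return a copy of `initializers` where the entries named in `results`
    are replaced by new (e.g. trained) values. The input is not modified."""
    updated = dict(initializers)
    for name, value in results.items():
        if name not in updated:
            raise KeyError(f"Unknown initializer {name!r}.")
        if updated[name].shape != value.shape:
            raise ValueError(f"Shape mismatch for {name!r}: "
                             f"{updated[name].shape} != {value.shape}.")
        # Keep the original element type (e.g. float32) in the model.
        updated[name] = value.astype(updated[name].dtype)
    return updated


inits = {"coef": np.zeros((2, 1), dtype=np.float32),
         "intercept": np.zeros(1, dtype=np.float32)}
trained = {"coef": np.array([[0.5], [1.5]])}  # float64 from the optimizer
new = replace_initializers_sketch(inits, trained)
print(new["coef"].dtype, new["coef"].ravel())  # float32 [0.5 1.5]
```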
onnxruntime
onnxcustom.utils.onnxruntime_helper.device_to_providers (device)
Returns the corresponding providers for a specific device.
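A sketch of a device-to-providers mapping. `CPUExecutionProvider` and `CUDAExecutionProvider` are onnxruntime's provider names; the accepted device strings and the CPU fallback listed after CUDA are assumptions of this sketch. A list like this is what `onnxruntime.InferenceSession` accepts through its `providers` argument.

```python
def device_to_providers_sketch(device: str):
    """Map 'cpu', 'cuda' or 'cuda:<id>' to onnxruntime provider names,
    most preferred first. ASSUMPTION: CUDA falls back to the CPU provider."""
    kind = device.split(":", 1)[0].lower()
    if kind == "cpu":
        return ["CPUExecutionProvider"]
    if kind == "cuda":
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    raise ValueError(f"Unexpected device {device!r}.")


print(device_to_providers_sketch("cpu"))     # ['CPUExecutionProvider']
print(device_to_providers_sketch("cuda:0"))  # ['CUDAExecutionProvider', 'CPUExecutionProvider']
```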
onnxcustom.utils.onnxruntime_helper.numpy_to_ort_value (arr, device = None)
Converts a numpy array to C_OrtValue.
onnxcustom.utils.onnxruntime_helper.get_ort_device (device)
Converts device into C_OrtDevice.
onnxcustom.utils.onnxruntime_helper.get_ort_device_type (device)
Converts device into device type.
onnxcustom.utils.onnxruntime_helper.ort_device_to_string (device)
Returns a string representing the device. Opposite of function get_ort_device.
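Since get_ort_device and ort_device_to_string are inverses, a string should survive a round trip through both. The sketch below illustrates that property with a hypothetical `DeviceSketch` tuple standing in for `C_OrtDevice`; the string format (`cpu`, `cuda:<id>`) is an assumption.

```python
from typing import NamedTuple


class DeviceSketch(NamedTuple):
    # Hypothetical stand-in for C_OrtDevice: a device kind and an id.
    kind: str
    device_id: int


def parse_device(device: str) -> DeviceSketch:
    """'cpu' -> ('cpu', 0), 'cuda:1' -> ('cuda', 1)."""
    kind, _, index = device.partition(":")
    return DeviceSketch(kind.lower(), int(index) if index else 0)


def device_to_string(device: DeviceSketch) -> str:
    """Inverse of parse_device for the canonical forms 'cpu' and 'cuda:<id>'."""
    if device.kind == "cpu" and device.device_id == 0:
        return "cpu"
    return f"{device.kind}:{device.device_id}"


for s in ["cpu", "cuda:0", "cuda:1"]:
    print(s, parse_device(s), device_to_string(parse_device(s)))
```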
onnxcustom.utils.onnxruntime_helper.provider_to_device (provider_name)
Converts provider into a device.
functions
onnxcustom.utils.orttraining_helper.add_loss_output (onx, score_name = 'squared_error', loss_name = 'loss', label_name = 'label', weight_name = None, penalty = None, output_index = None, **kwargs)
Modifies an ONNX graph to add operators that score its output (compute a loss) so that the model can be trained.
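To see what the added scoring operators compute, here is a numpy sketch of a `squared_error` loss. It assumes the loss is the sum (not the mean) of squared differences between prediction and label, each multiplied by a per-sample weight when one is given; the exact reduction used by add_loss_output is not stated here.

```python
import numpy as np


def squared_error_loss_sketch(prediction, label, weight=None):
    """ASSUMPTION: loss = sum over samples of weight * (prediction - label)**2."""
    diff = (np.asarray(prediction, dtype=np.float64)
            - np.asarray(label, dtype=np.float64))
    sq = diff.reshape(diff.shape[0], -1) ** 2  # one row per sample
    if weight is not None:
        sq = sq * np.asarray(weight, dtype=np.float64).reshape(-1, 1)
    return float(sq.sum())


pred = np.array([[1.0], [2.0], [4.0]])
label = np.array([[1.5], [2.0], [3.0]])
print(squared_error_loss_sketch(pred, label))             # 1.25
print(squared_error_loss_sketch(pred, label, [2, 1, 1]))  # 1.5
```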
onnxcustom.utils.onnx_function.get_supported_functions ()
Returns the list of functions supported by function_onnx_graph.
onnxcustom.utils.onnx_function.function_onnx_graph (name, target_opset = None, dtype = <class 'numpy.float32'>, weight_name = None, **kwargs)
Returns the ONNX graph corresponding to a function.
onnxcustom.utils.orttraining_helper.penalty_loss_onnx (name, dtype, l1 = None, l2 = None, existing_names = None)
Returns ONNX nodes to compute $|w|\alpha + w^2\beta$ where $\alpha = l1$ and $\beta = l2$.
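A numpy sketch of the same penalty, assuming the elementwise terms $|w|\alpha$ and $w^2\beta$ are summed over all coefficients of every weight tensor. `penalty_sketch` is a hypothetical name; passing `None` (or 0) for a coefficient disables its term.

```python
import numpy as np


def penalty_sketch(weights, l1=None, l2=None):
    """Return sum(|w|) * l1 + sum(w**2) * l2 over all weight tensors.
    A missing or zero coefficient disables the corresponding term."""
    total = 0.0
    for w in weights:
        w = np.asarray(w, dtype=np.float64)
        if l1:
            total += l1 * np.abs(w).sum()
        if l2:
            total += l2 * (w ** 2).sum()
    return float(total)


# |1| + |-2| = 3 and 1 + 4 = 5, so the penalty is about 0.1 * 3 + 0.5 * 5.
print(penalty_sketch([np.array([1.0, -2.0])], l1=0.1, l2=0.5))
```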
gradient
onnxcustom.training.grad_helper.onnx_derivative (onx, weights = None, inputs = None, options = DerivativeOptions.Zero, loss = None, label = None, path_name = None)
Builds the gradient for an ONNX graph.
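Whatever graph onnx_derivative produces, a common way to validate a gradient is to compare it with central finite differences. The sketch below runs that check on a hand-written numpy gradient of a squared-error linear model, which stands in for the output of a gradient graph; no ONNX graph is built here, and all names are hypothetical.

```python
import numpy as np


def loss(w, X, y):
    # Squared-error loss of a linear model without intercept.
    r = X @ w - y
    return float(r @ r)


def analytic_grad(w, X, y):
    # d/dw sum((Xw - y)^2) = 2 X^T (Xw - y)
    return 2.0 * X.T @ (X @ w - y)


def finite_diff_grad(f, w, eps=1e-6):
    # Central differences, one coordinate at a time.
    g = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = eps
        g[i] = (f(w + step) - f(w - step)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
y = rng.normal(size=5)
w = rng.normal(size=3)
print(np.allclose(analytic_grad(w, X, y),
                  finite_diff_grad(lambda v: loss(v, X, y), w), atol=1e-5))  # True
```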