Coverage for mlprodict/onnx_tools/exports/skl2onnx_helper.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Helpers to run examples created with :epkg:`sklearn-onnx`.
4"""
5from onnx import helper, TensorProto
def _copy_inout(inout, scope, new_name):
    """
    Copies the element type and shape of an ONNX input or output
    (``ValueInfoProto`` holding a tensor type) into a new value info
    named *new_name*.

    Static dimensions (``dim_value``) are kept as integers and symbolic
    dimensions (``dim_param``) as strings. A dimension holding neither
    becomes an unknown dimension. If the original has no shape at all,
    the copy has no shape either (unknown rank) instead of becoming a
    scalar.

    :param inout: input or output to copy
    :param scope: unused, kept for API compatibility
    :param new_name: name of the new value info
    :return: new ``ValueInfoProto``
    """
    tensor_type = inout.type.tensor_type
    if tensor_type.HasField('shape'):
        shape = []
        for dim in tensor_type.shape.dim:
            kind = dim.WhichOneof('value')
            if kind == 'dim_value':
                shape.append(dim.dim_value)
            elif kind == 'dim_param':
                # symbolic dimension such as a batch size, keep its name
                shape.append(dim.dim_param)
            else:
                shape.append(None)
    else:
        # no shape means the rank is unknown, an empty list would mean scalar
        shape = None
    return helper.make_tensor_value_info(
        new_name, tensor_type.elem_type, shape)
15def _clean_variable_name(name, scope):
16 return scope.get_unique_variable_name(name)
19def _clean_operator_name(name, scope):
20 return scope.get_unique_operator_name(name)
23def _clean_initializer_name(name, scope):
24 return scope.get_unique_variable_name(name)
def _map_graph_result_names(graph, scope):
    """
    Maps every result and initializer name of *graph* to a new name
    made unique by *scope* (helper of :func:`add_onnx_graph`).
    """
    # An empty name marks an omitted optional input/output: keep it empty.
    name_mapping = {'': ''}

    def _register(name, cleaner):
        # Register each name once so the scope does not reserve
        # several unique names for the same result.
        if name not in name_mapping:
            name_mapping[name] = cleaner(name, scope)

    # Initializers are tensors, so they use the variable namespace.
    for init in graph.initializer:
        _register(init.name, _clean_initializer_name)
    # Graph inputs/outputs may not be referenced by any node
    # (e.g. an input directly used as an output).
    for value in list(graph.input) + list(graph.output):
        _register(value.name, _clean_variable_name)
    for node in graph.node:
        for name in node.input:
            _register(name, _clean_variable_name)
        for name in node.output:
            _register(name, _clean_variable_name)
    return name_mapping


def add_onnx_graph(scope, operator, container, onx):
    """
    Adds a whole ONNX graph to an existing one following
    :epkg:`skl2onnx` API assuming this ONNX graph implements
    an `operator <http://onnx.ai/sklearn-onnx/api_summary.html?
    highlight=operator#skl2onnx.common._topology.Operator>`_.

    Renaming:
    - Every result, initializer and named node of *onx* is renamed
      with a unique name obtained from *scope*.
    - Result and initializer names use the variable namespace; node
      names use the operator namespace.
    - Empty result names (omitted optional inputs/outputs) stay empty.

    Wiring:
    - The operator inputs are connected to the graph inputs, matched
      by position, through *Identity* nodes.
    - The graph outputs are connected to the operator outputs in the
      same way.
    - ``zip`` stops at the shorter of the two sequences.

    Opsets: the opsets imported by *onx* are registered in the container.

    :param scope: scope (to get unique names)
    :param operator: operator
    :param container: container
    :param onx: ONNX graph

    .. note::
        Subgraphs stored in node attributes (*If*, *Loop*, *Scan*) are
        copied as they are; names inside them are not renamed.
    """
    graph = onx.graph
    name_mapping = _map_graph_result_names(graph, scope)

    inputs = [_copy_inout(o, scope, name_mapping[o.name])
              for o in graph.input]
    outputs = [_copy_inout(o, scope, name_mapping[o.name])
               for o in graph.output]

    for inp, to in zip(operator.inputs, inputs):
        n = helper.make_node('Identity', [inp.onnx_name], [to.name],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    for inp, to in zip(outputs, operator.outputs):
        n = helper.make_node('Identity', [inp.name], [to.onnx_name],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    for node in graph.node:
        # A unique name is requested per node, so duplicated node names
        # in onx still produce distinct names; unnamed nodes stay unnamed.
        new_name = (_clean_operator_name(node.name, scope)
                    if node.name else None)
        n = helper.make_node(
            node.op_type,
            [name_mapping[o] for o in node.input],
            [name_mapping[o] for o in node.output],
            name=new_name,
            domain=node.domain if node.domain else None)
        n.attribute.extend(node.attribute)  # pylint: disable=E1101
        container.nodes.append(n)

    for init in graph.initializer:
        tensor = TensorProto()
        tensor.CopyFrom(init)
        tensor.name = name_mapping[init.name]
        container.initializers.append(tensor)

    # opset
    for oimp in onx.opset_import:
        container.node_domain_version_pair_sets.add(
            (oimp.domain, oimp.version))