Coverage for mlprodict/onnx_conv/sklconv/function_transformer_converters.py: 86%
1"""
2@file
3@brief Rewrites some of the converters implemented in
4:epkg:`sklearn-onnx`.
5"""
6import copy
7from onnx.helper import make_tensor
8from skl2onnx.common.data_types import guess_numpy_type
9from skl2onnx.common._apply_operation import apply_concat, apply_identity
10from ...onnx_tools.onnx2py_helper import _var_as_dict, guess_proto_dtype
11from ...npy.onnx_version import FctVersion

def new_calculate_sklearn_function_transformer_output_shapes(operator):
    """
    Rewrites the shape calculator implemented in
    :epkg:`sklearn-onnx` to support custom functions
    implemented with :ref:`l-numpy-onnxpy`.
    """
    fct = operator.raw_operator.func
    if hasattr(fct, 'signed_compiled'):
        # the function was compiled for several dtypes,
        # picks the version matching the input dtype
        dtype = guess_numpy_type(operator.inputs[0].type)
        fct = fct[FctVersion((dtype, ), None)]
    if hasattr(fct, 'compiled'):
        compiled = fct.compiled
        if not hasattr(compiled, 'onnx_'):
            raise RuntimeError(  # pragma: no cover
                "Attribute 'onnx_' is missing, function was not "
                "converted to onnx.")
        onx = compiled.onnx_
        graph = onx.graph
        outputs = graph.output

        # Let's assume there is only one output
        # with the same type as the input.
        # Only the shape changes.
        if len(outputs) != 1:
            raise RuntimeError(  # pragma: no cover
                "Only one output is allowed not %d." % len(outputs))
        input_type = operator.inputs[0].type.__class__
        if compiled.meta_.get('signature', None):
            dims = compiled.meta_['signature'].shape_calculator(
                operator.inputs[0].type.shape)
        else:
            # no signature available: keeps the batch dimension and
            # takes the remaining dimensions from the ONNX graph output
            N = operator.inputs[0].type.shape[0]
            dims = [N]
            out = outputs[0]
            try:
                extra_dims = out.type.tensor_type.shape.dim
            except AttributeError:  # pragma: no cover
                extra_dims = None
            if extra_dims is not None:
                val = [d.dim_value if d.dim_value > 0 else None
                       for d in extra_dims[1:]]
                dims.extend(val)
        operator.outputs[0].type = input_type(dims)
        return

    if operator.raw_operator.func is not None:
        raise TypeError("FunctionTransformer is not supported unless the "
                        "transform function is wrapped with onnxnumpy, "
                        "got type %r." % type(operator.raw_operator.func))

    # default case (func is None): the output concatenates the columns
    # of every input, the column count is unknown if any input's is
    N = operator.inputs[0].type.shape[0]
    C = 0
    for variable in operator.inputs:
        if variable.type.shape[1] is not None:
            C += variable.type.shape[1]
        else:
            C = None
            break

    operator.outputs[0].type = operator.inputs[0].type.__class__([N, C])
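
# A worked example of the default branch above, added as a hedged
# sketch (not part of the original module): with inputs of shapes
# [N, 3] and [N, 5] the output becomes [N, 8]; one unknown column
# count collapses the result to [N, None]. _example_default_dims is
# a hypothetical helper mirroring the column-summing logic.
def _example_default_dims(shapes):
    "Mirrors the column-summing logic of the shape calculator above."
    N = shapes[0][0]
    C = 0
    for shape in shapes:
        if shape[1] is not None:
            C += shape[1]
        else:
            C = None
            break
    return [N, C]


assert _example_default_dims([[None, 3], [None, 5]]) == [None, 8]
assert _example_default_dims([[None, 3], [None, None]]) == [None, None]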

def _copy_attributes(att):
    """
    Copies one node attribute, rebuilding the tensor
    with :func:`make_tensor` when the attribute holds one.
    """
    if hasattr(att, 'value'):
        return att.value
    vt = _var_as_dict(att)
    if vt['type']['kind'] == 'tensor':
        value = vt['value']
        return make_tensor(att.name, guess_proto_dtype(value.dtype),
                           value.shape, value.ravel().tolist())
    if vt['type']['kind'] == 'real':
        return vt['value']
    raise RuntimeError(  # pragma: no cover
        "Unable to copy attribute %r, got %r." % (att, vt))

def new_convert_sklearn_function_transformer(scope, operator, container):
    """
    Rewrites the converter implemented in
    :epkg:`sklearn-onnx` to support custom functions
    implemented with :ref:`l-numpy-onnxpy`.
    """
    op = operator.raw_operator
    fct = op.func
    if hasattr(fct, 'signed_compiled'):
        # picks the compiled version matching the input dtype
        dtype = guess_numpy_type(operator.inputs[0].type)
        fct = fct[FctVersion((dtype, ), None)]
    if hasattr(fct, 'compiled'):
        compiled = fct.compiled
        if not hasattr(compiled, 'onnx_'):
            raise RuntimeError(  # pragma: no cover
                "Attribute 'onnx_' is missing, function was not "
                "converted to onnx.")
        onx = compiled.onnx_
        graph = onx.graph
        nodes = graph.node

        # renaming all intermediate variables
        names = []
        for node in nodes:
            for name in node.input:
                names.append(name)
            for name in node.output:
                names.append(name)
        names = set(names)
        names_mapping = {}
        for name in names:
            names_mapping[name] = scope.get_unique_variable_name(
                'ft_%s' % name)

        # adding identities to connect the copied subgraph
        # to the operator inputs and outputs
        apply_identity(scope, operator.inputs[0].full_name,
                       names_mapping[graph.input[0].name], container)
        apply_identity(scope, names_mapping[graph.output[0].name],
                       operator.outputs[0].full_name, container)

        # adding initializers
        for init in graph.initializer:
            init = copy.deepcopy(init)
            name = names_mapping[init.name]
            init.name = name
            content = init.SerializeToString()
            container.initializers_strings[content] = name
            container.initializers.append(init)

        # adding nodes
        for node in nodes:
            atts = {}
            for att in node.attribute:
                atts[att.name] = _copy_attributes(att)
            container.add_node(
                node.op_type,
                [names_mapping[n] for n in node.input],
                [names_mapping[n] for n in node.output],
                name=scope.get_unique_operator_name('ft_%s' % node.op_type),
                **atts)
        return

    if op.func is not None:
        raise TypeError(  # pragma: no cover
            "FunctionTransformer is not supported unless the "
            "transform function is wrapped with onnxnumpy, "
            "got type %r." % type(op.func))

    # func is None: identity for a single input,
    # concatenation of all inputs otherwise
    if len(operator.inputs) == 1:
        apply_identity(scope, operator.inputs[0].full_name,
                       operator.outputs[0].full_name, container)
    else:
        apply_concat(scope, [i.full_name for i in operator.inputs],
                     operator.outputs[0].full_name, container)
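
# End-to-end usage sketch, added as an illustration (not part of the
# original module) and kept as a comment because importing
# mlprodict.onnx_conv here would be circular. Names follow mlprodict's
# documented npy API; custom_log1p is a hypothetical example function.
# A function decorated with onnxnumpy can be placed inside a
# FunctionTransformer and converted through the two functions above:
#
#     from typing import Any
#     import numpy
#     from sklearn.preprocessing import FunctionTransformer
#     from mlprodict.npy import onnxnumpy_default, NDArray
#     import mlprodict.npy.numpy_onnx_impl as npnx
#     from mlprodict.onnx_conv import to_onnx
#
#     @onnxnumpy_default
#     def custom_log1p(x: NDArray[Any, numpy.float32]
#                      ) -> NDArray[Any, numpy.float32]:
#         "log(1 + x) written with ONNX operators"
#         return npnx.log(x + numpy.float32(1))
#
#     X = numpy.random.rand(5, 3).astype(numpy.float32)
#     tr = FunctionTransformer(custom_log1p)
#     tr.fit(X)
#     onx = to_onnx(tr, X)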