Coverage for mlprodict/onnx_tools/optim/graph_schema_helper.py: 78%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Functions to help guessing the final graph structure.
4"""
5import numpy
6from onnx import TensorProto
def _guess_type(var):
    """
    Guesses the :epkg:`ONNX` type of *var*, unwrapping a constant stored
    as ``{'value': ...}`` first.

    The :epkg:`skl2onnx` import is delayed to avoid a hard module-level
    dependency.
    """
    from skl2onnx.algebra.type_helper import (  # delayed
        _guess_type as skl2onnx__guess_type)
    if isinstance(var, dict) and 'value' in var:
        var = var['value']  # pragma: no cover
    return skl2onnx__guess_type(var)
def get_defined_inputs(input_names, variables=None, dtype=None,
                       schema=None):
    """
    Retrieves defined inputs in already declared variables
    based on their names.

    @param      input_names     input names
    @param      variables       registered variables created
                                by previous operators
    @param      dtype           float computational type
    @param      schema          defined inputs by schema (*expected_inputs*)
    @return                     typed inputs as ``tuple(name, type)``
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType, FloatTensorType, DoubleTensorType)

    def _default_tensor():
        # Fallback type: a float tensor matching the computational dtype.
        return (DoubleTensorType() if dtype == numpy.float64
                else FloatTensorType())

    def _resolve(name, sch):
        # Determines the type of input *name*, preferring registered
        # variables over the schema, over the dtype fallback.
        if variables is None:
            if sch is None or not isinstance(sch, (DataType, tuple)):
                return _default_tensor()  # pragma: no cover
            return sch if isinstance(sch, DataType) else sch[1]
        if name in variables:
            known = variables[name]
            if isinstance(known, DataType):
                if 0 in known.shape:
                    raise RuntimeError(  # pragma: no cover
                        "Shape cannot be empty: name='{}', var={}".format(
                            name, known))
                return variables[name]
            if isinstance(known, dict) and 'value' in known:
                # Constant: guess its type from the stored value.
                cst = known['value']
                try:
                    return _guess_type(cst)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to guess type of variable '{}' - {}."
                        "".format(name, cst)) from e
            raise NotImplementedError(  # pragma: no cover
                "Unable to guess type for '{}' form '{}'.".format(
                    name, variables[name]))
        if isinstance(sch, (DataType, tuple)):
            candidate = sch if isinstance(sch, DataType) else sch[1]
            if not isinstance(candidate, str):
                return candidate
        # Inputs. Let's assume it is a vector of floats.
        return _default_tensor()

    if schema is None or len(schema) < len(input_names):
        # Not enough schema entries: resolve every input without one.
        return [(name, _resolve(name, None)) for name in input_names]
    return [(name, _resolve(name, sch))
            for name, sch in zip(input_names, schema)]
def get_defined_outputs(outputs, onnx_node, typed_inputs=None, variables=None,
                        dtype=None, schema=None, schema_inputs=None):
    """
    Gets types of predefined outputs when they cannot be inferred.
    Some part of it should be automated based
    on type constraints.

    :param outputs: requested outputs
    :param onnx_node: :epkg:`ONNX` node definition
    :param typed_inputs: known typed inputs of the node as `tuple(name, type)`
    :param variables: registered variables created by previous operators
    :param dtype: float computational type
    :param schema: defined outputs by schema (*expected_outputs*)
    :param schema_inputs: defined inputs by schema (*expected_inputs*)
    :return: typed outputs as ``tuple(name, type)``
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType,
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        DoubleTensorType, _guess_type_proto, _guess_type_proto_str)

    # Determines the default tensor class *ft* used whenever nothing
    # more specific can be inferred for an output.
    if schema is None:
        ft = DoubleTensorType if dtype == numpy.float64 else FloatTensorType
    elif len(schema) != 1:
        raise ValueError(  # pragma: no cover
            "schema should only contain one output not {}.".format(schema))
    else:
        # NOTE(review): ``schema[0]`` on a DataType instance looks
        # suspicious (a DataType is not indexable) but the branch is
        # kept as-is to preserve behavior — confirm upstream.
        if isinstance(schema, DataType):
            ft = schema[0].__class__
        else:
            ft = schema[0][1].__class__

    op = onnx_node.op_type
    if op in {'ZipMap', 'ArgMin', 'ArgMax', 'Shape',
              'Greater', 'Less', 'Equal', 'TopK',
              'Cast', 'ArrayFeatureExtractor',
              'Reshape', 'Transpose', 'Scan',
              'ConstantOfShape'}:
        # Hard-coded type rules for a handful of well-known operators.
        if op == "ZipMap":
            # ZipMap returns a sequence of dictionaries.
            zip_type = SequenceType(DictionaryType(Int64Type(), ft()))
            outputs = [(name, zip_type) for name in outputs]
        elif op in ("ArgMin", "ArgMax", 'Shape') and len(outputs) == 1:
            # Index/shape operators always produce int64 tensors.
            outputs = [(outputs[0], Int64TensorType())]
        elif op in ("Greater", "Less", 'Equal') and len(outputs) == 1:
            # Comparison operators produce boolean tensors.
            outputs = [(outputs[0], BooleanTensorType())]
        elif op == "TopK" and len(outputs) == 2:
            # TopK: values keep the input type, indices are int64.
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    "Wrong typed_inputs, got {}.".format(typed_inputs))
            outputs = [(outputs[0], typed_inputs[0][1]),
                       (outputs[1], Int64TensorType())]
        elif op == "Cast" and len(outputs) == 1:
            # Cast: the target type is stored in the first attribute.
            cast_to = _guess_type_proto(onnx_node.attribute[0].i, dims=None)
            outputs = [(outputs[0], cast_to)]
        elif op == "ArrayFeatureExtractor":
            # ArrayFeatureExtractor keeps the type of its first input.
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    "Wrong typed_inputs, got {}.".format(typed_inputs))
            outputs = [(outputs[0], typed_inputs[0][1])]
        elif op in ('Reshape', 'Transpose'):
            # Same element type as the first input, shape unknown.
            outputs = [(outputs[0], typed_inputs[0][1].__class__())]
        elif op == 'Scan':
            # Scan: each output mirrors the type of the matching input.
            if len(outputs) != len(typed_inputs):
                raise RuntimeError(  # pragma: no cover
                    "Dimension mismatch, operator Scan should have "
                    "the same number of inputs and outputs {} != {}"
                    ".".format(len(outputs), len(typed_inputs)))
            outputs = [(out, inp[1].__class__())
                       for out, inp in zip(outputs, typed_inputs)]
        elif op == "ConstantOfShape":
            # ConstantOfShape falls back to the default tensor type.
            outputs = [(outputs[0], ft())]
    elif 'Classifier' in op:
        # Good chance that's a classifier.
        outputs = [(outputs[0], Int64TensorType()),
                   (outputs[1], ft())]
    else:
        if schema_inputs is not None and schema is not None:
            # Binds schema type variables (e.g. 'T') to the concrete
            # types of the provided inputs, then resolves each output.
            bound = {}
            for got, exp in zip(typed_inputs, schema_inputs):
                if isinstance(exp[1], str):
                    bound[exp[1]] = got
            resolved = []
            for pos, out in enumerate(outputs):
                if isinstance(out, str):
                    exp = schema[pos]
                    if exp[1] in bound:
                        resolved.append((out, bound[exp[1]][1].__class__()))
                    else:
                        resolved.append(
                            (out, _guess_type_proto_str(exp[1], None)))
                elif (isinstance(out, tuple) and
                        (isinstance(out[1], str) or out[1] is None)):
                    exp = schema[pos]
                    if exp[1] in bound:
                        resolved.append(
                            (out[0], bound[exp[1]][1].__class__()))
                    else:
                        resolved.append(
                            (out[0], _guess_type_proto_str(exp[1], None)))
                else:
                    resolved.append(out)
            outputs = resolved
        elif len(typed_inputs) == 1 and len(outputs) == 1:
            # Default case
            # Assuming the only output is the same as the only input.
            outputs = [(outputs[0], typed_inputs[0][1])]
        else:
            # Default
            outputs = [(name, ft()) for name in outputs]

    # Final sanity checks: every output must be a (str, resolved type) pair.
    for name, typ in outputs:
        if typ in ('T', None, '', 'I'):
            raise NotImplementedError(  # pragma: no cover
                "Undefined output type: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    typ, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
        if not isinstance(name, str):
            raise NotImplementedError(  # pragma: no cover
                "Undefined output type: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    typ, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
    return outputs
def proto2vars(values):
    """
    Converts proto values to Variables.

    :param values: iterable of :epkg:`onnx` value infos or type protos
    :return: list of ``(name, type)`` pairs where *type* is a
        :epkg:`skl2onnx` data type
    :raises RuntimeError: if a proto cannot be converted or if the
        resulting variable still carries a 0 in its shape
    :raises NotImplementedError: for unsupported element types
    """
    from skl2onnx.common.data_types import (  # delayed
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        Int32TensorType, DoubleTensorType, FloatType,
        StringTensorType, Float16TensorType)

    def ptype2vttype(it, shape):
        # Maps a TensorProto element type onto a skl2onnx tensor type.
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatTensorType(shape)
        if it == TensorProto.DOUBLE:  # pylint: disable=E1101
            return DoubleTensorType(shape)
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64TensorType(shape)
        if it == TensorProto.INT32:  # pylint: disable=E1101
            return Int32TensorType(shape)
        if it == TensorProto.BOOL:  # pylint: disable=E1101
            return BooleanTensorType(shape)
        if it == TensorProto.STRING:  # pylint: disable=E1101
            return StringTensorType(shape)
        # BUG FIX: the original guard was ``if Float16TensorType is None``
        # which would then call ``None(shape)``; the type is only usable
        # when it is NOT None (it may be missing in old skl2onnx versions).
        if Float16TensorType is not None:
            if it == TensorProto.FLOAT16:  # pylint: disable=E1101
                return Float16TensorType(shape)
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {} with shape {}".format(it, shape))

    def ptype2vtype(it):
        # Maps a TensorProto element type onto a skl2onnx scalar type.
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatType()
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64Type()
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {}".format(it))

    res = []
    for v_ in values:
        v = v_
        name = v.name if hasattr(v, 'name') else None
        if hasattr(v, 'type') and str(v.type) != '':
            # ValueInfoProto: recurse on the wrapped TypeProto.
            t = v.type
            v = proto2vars([t])[0][1]
        elif hasattr(v, 'sequence_type') and str(v.sequence_type) != '':
            subtype = proto2vars([v.sequence_type.elem_type])[0][1]
            v = SequenceType(subtype)
        elif hasattr(v, 'tensor_type') and str(v.tensor_type) != '':
            tt = v.tensor_type
            dim = tt.shape.dim
            # An absent dim list means a scalar-like empty shape.
            shape = [d.dim_value for d in dim]
            v = ptype2vttype(tt.elem_type, shape)
        elif hasattr(v, 'map_type') and str(v.map_type) != '':
            mt = v.map_type
            keyt = ptype2vtype(mt.key_type)
            valt = proto2vars([mt.value_type])[0][1]
            v = DictionaryType(keyt, valt)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unable to build a variable from {}.".format(v))
        if v.shape is not None and 0 in v.shape:
            # Replaces 0 by None
            new_shape = tuple(None if d == 0 else d for d in v.shape)
            if new_shape in ((None, ), None):
                v = v.__class__()
            else:
                v = v.__class__(new_shape)
        if v.shape is not None and 0 in v.shape:
            raise RuntimeError(  # pragma: no cover
                "Shape cannot be empty: '{}': {}.".format(
                    name, v_))
        res.append((name, v))
    return res