Coverage for mlprodict/npy/xop_variable.py: 94%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`.
5.. versionadded:: 0.9
6"""
7import numpy
8from onnx import ValueInfoProto
9from onnx.helper import make_tensor_type_proto
10from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
11from onnx.defs import onnx_opset_version
12from .. import __max_supported_opset__
def max_supported_opset():
    """
    Returns the latest supported opset for the main domain.

    The result is the smaller of ``__max_supported_opset__`` (imported
    from the package root) and the value returned by
    ``onnx.defs.onnx_opset_version()``.

    .. runpython::
        :showcode:

        from mlprodict.npy.xop_variable import max_supported_opset
        print("max_supported_opset() returns", max_supported_opset())

    :return: the smaller of the two opset values
    """
    installed_opset = onnx_opset_version()
    # Never report more than what the installed onnx package provides.
    if installed_opset < __max_supported_opset__:
        return installed_opset
    return __max_supported_opset__
def is_numpy_dtype(dtype):
    """
    Tells if *dtype* is a numpy dtype, i.e. a key of
    ``NP_TYPE_TO_TENSOR_TYPE`` either directly or once converted
    with ``numpy.dtype``.

    :param dtype: anything
    :return: boolean
    """
    # Lists, dictionaries and Variable instances are rejected
    # before any dictionary lookup is attempted.
    if isinstance(dtype, (list, dict, Variable)):
        return False
    if dtype in NP_TYPE_TO_TENSOR_TYPE:
        return True
    # Second attempt with the value normalised into a numpy.dtype.
    return numpy.dtype(dtype) in NP_TYPE_TO_TENSOR_TYPE
def numpy_type_prototype(dtype):
    """
    Converts a numpy dtype into a TensorProto dtype.

    The lookup in ``NP_TYPE_TO_TENSOR_TYPE`` is first done with *dtype*
    as given, then with ``numpy.dtype(dtype)``.

    :param dtype: dtype
    :return: proto dtype
    :raises ValueError: if neither lookup succeeds
    """
    if dtype in NP_TYPE_TO_TENSOR_TYPE:
        return NP_TYPE_TO_TENSOR_TYPE[dtype]
    normalized = numpy.dtype(dtype)
    if normalized not in NP_TYPE_TO_TENSOR_TYPE:
        raise ValueError(  # pragma: no cover
            "Unable to convert dtype %r into ProtoType." % dtype)
    return NP_TYPE_TO_TENSOR_TYPE[normalized]
def guess_numpy_type(data_type):
    """
    Guesses the corresponding numpy type based on *data_type*.

    Resolution order:

    * a numpy type from the supported list is returned unchanged,
    * python ``str`` and ``bool`` become ``numpy.str_`` and
      ``numpy.bool_``,
    * an object whose class name is one of the known tensor type names
      (``'FloatTensorType'``, ``'Int64TensorType'``, ...) is mapped to
      the matching numpy type,
    * an object with an attribute ``type`` is resolved recursively
      from that attribute.

    :param data_type: numpy type, python type, or an object describing
        a tensor type
    :return: numpy type
    :raises NotImplementedError: if no rule applies
    """
    # complex64 / complex128 are accepted as is so that the numpy types
    # produced by the name mapping below are also valid inputs.
    if data_type in (numpy.float64, numpy.float32, numpy.int8, numpy.uint8,
                     numpy.str_, numpy.bool_, numpy.int32, numpy.int64,
                     numpy.complex64, numpy.complex128):
        return data_type
    if data_type == str:
        return numpy.str_
    if data_type == bool:
        return numpy.bool_
    name2numpy = {
        'FloatTensorType': numpy.float32,
        'DoubleTensorType': numpy.float64,
        'Int32TensorType': numpy.int32,
        'Int64TensorType': numpy.int64,
        'StringTensorType': numpy.str_,
        'BooleanTensorType': numpy.bool_,
        'Complex64TensorType': numpy.complex64,
        'Complex128TensorType': numpy.complex128,
    }
    # The class is matched by name only: no import of the library
    # defining these tensor types is needed.
    cl_name = data_type.__class__.__name__
    if cl_name in name2numpy:
        return name2numpy[cl_name]
    if hasattr(data_type, 'type'):
        return guess_numpy_type(data_type.type)
    raise NotImplementedError(  # pragma: no cover
        "Unsupported data_type '{}'.".format(data_type))
class Variable:
    """
    An input or output to an ONNX graph.

    :param name: name, or another *Variable* to copy (every other
        parameter must then be None)
    :param dtype: :epkg:`numpy` dtype (can be None)
    :param shape: shape (can be None)
    :param added_dtype: :epkg:`numpy` dtype specified at conversion time
        (can be None)
    :param added_shape: :epkg:`numpy` shape specified at conversion time
        (can be None)
    """

    def __init__(self, name, dtype=None, shape=None, added_dtype=None,
                 added_shape=None):
        # Reject values which are obviously not a dtype or not a shape.
        if (dtype is not None and isinstance(
                dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                "Unexpected type %r for dtype." % type(dtype))
        if (added_dtype is not None and isinstance(
                added_dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                "Unexpected type %r for added_dtype." % type(added_dtype))
        if shape is not None and not isinstance(shape, (tuple, list)):
            raise TypeError(
                "Unexpected type %r for shape." % type(shape))
        if (added_shape is not None and not isinstance(
                added_shape, (tuple, list))):
            raise TypeError(
                "Unexpected type %r for added_shape." % type(added_shape))

        if isinstance(name, Variable):
            # Copy constructor: every attribute comes from *name*.
            if (dtype is not None or shape is not None or
                    added_dtype is not None or added_shape is not None):
                raise ValueError(  # pragma: no cover
                    "If name is a Variable, then all others attributes "
                    "should be None.")

            self.name_ = name.name_
            self.dtype_ = name.dtype_
            self.added_dtype_ = name.added_dtype_
            self.shape_ = name.shape_
            self.added_shape_ = name.added_shape_
        else:
            if not isinstance(name, str):
                raise TypeError(  # pragma: no cover
                    "name must be a string not %r." % type(name))

            self.name_ = name
            self.dtype_ = dtype
            self.added_dtype_ = added_dtype
            self.shape_ = shape
            self.added_shape_ = added_shape

    def to_skl2onnx(self, scope=None):
        """
        Converts this instance into an instance of *Variable*
        from :epkg:`sklearn-onnx`.

        :param scope: forwarded to the :epkg:`sklearn-onnx` constructor
        :return: instance of *Variable* from :epkg:`sklearn-onnx`
        """
        from skl2onnx.common._topology import Variable as skl2onnxVariable  # delayed
        from skl2onnx.common.data_types import _guess_numpy_type  # delayed
        inst = _guess_numpy_type(self.dtype, self.shape)
        var = skl2onnxVariable(self.name, self.name, type=inst, scope=scope)
        return var

    @staticmethod
    def from_skl2onnx(var):
        """
        Converts var from :epkg:`sklearn-onnx` into this class.

        :param var: object exposing `onnx_name` and `type`
            (with a `shape` attribute)
        :return: @see cl Variable
        """
        return Variable(var.onnx_name, guess_numpy_type(var.type),
                        shape=var.type.shape)

    @property
    def name(self):
        "Returns the variable name (`self.name_`)."
        return self.name_

    @property
    def dtype(self):
        "Returns `self.dtype_`."
        return self.dtype_

    @property
    def shape(self):
        "Returns `self.shape_`."
        return self.shape_

    @property
    def proto_type(self):
        "Returns the proto type for `self.dtype_`, 0 if it is None."
        if self.dtype_ is None:
            return 0
        return numpy_type_prototype(self.dtype_)

    @property
    def proto_added_type(self):
        """
        Returns the proto type for `self.added_dtype_` if it is not None,
        for `self.dtype_` otherwise, 0 if both are None.
        """
        # Explicit None test instead of `or`: a numpy.dtype instance
        # without fields has length 0 and is therefore falsy, it would
        # be wrongly replaced by dtype_ (method copy_add stores such
        # an instance when it receives an array).
        dt = self.added_dtype_
        if dt is None:
            dt = self.dtype_
        if dt is None:
            return 0
        return numpy_type_prototype(dt)

    @property
    def proto_added_shape(self):
        """
        Returns `self.added_shape_` as a list if it is not None,
        `self.shape_` as a list otherwise, None if both are None.
        """
        # Explicit None test instead of `or`: an empty added shape
        # `()` is a valid value and must not be replaced by shape_.
        shape = self.added_shape_
        if shape is None:
            shape = self.shape_
        if shape is None:
            return None
        return list(shape)

    def __repr__(self):
        "usual"
        kwargs = dict(dtype=self.dtype_, shape=self.shape_,
                      added_dtype=self.added_dtype_,
                      added_shape=self.added_shape_)
        # Only attributes which are not None are displayed.
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if len(kwargs) > 0:
            msg = ", " + ", ".join("%s=%r" % (k, v) for k, v in kwargs.items())
        else:
            msg = ''
        return "%s(%r%s)" % (
            self.__class__.__name__, self.name_, msg)

    def is_named(self, name):
        """
        Tells the variable is named like that.

        :param name: string
        :return: boolean
        """
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                "name is expected to be a string not %r." % type(name))
        return self.name == name

    def copy_add(self, dtype):
        """
        Returns a copy of this variable with a new dtype.

        :param dtype: added type, if it is an array, its dtype and
            its shape become the added dtype and the added shape
        :return: @see cl Variable
        """
        if self.added_dtype_ is not None:
            raise RuntimeError(  # pragma: no cover
                "Cannot copy as added_dtype is not None.")
        if isinstance(dtype, numpy.ndarray):
            dtype, shape = dtype.dtype, dtype.shape
        else:
            shape = None
        return Variable(self.name_, self.dtype_, self.shape_, dtype, shape)

    def copy_merge(self, var):
        """
        Merges information from both Variable.

        The copy keeps the attributes of this instance. If *var* is a
        *Variable*, its dtype and shape fill `added_dtype_` and
        `added_shape_` when these are None. Otherwise *var* is
        given to method `copy_add`.

        :param var: @see cl Variable or anything accepted by `copy_add`
        :return: @see cl Variable
        """
        if not isinstance(var, Variable):
            return self.copy_add(var)
        res = Variable(self.name_, self.dtype_,
                       self.shape_, self.added_dtype_,
                       self.added_shape_)
        if self.added_dtype_ is None and var.dtype_ is not None:
            res.added_dtype_ = var.dtype_
        if self.added_shape_ is None and var.shape_ is not None:
            res.added_shape_ = var.shape_
        return res

    def copy_name(self, name):
        """
        Returns a copy with a new name.

        :param name: new name, the current one is kept if *name*
            is None or empty
        :return: @see cl Variable
        """
        return Variable(
            name or self.name_, self.dtype_,
            self.shape_, self.added_dtype_,
            self.added_shape_)

    def __eq__(self, other):
        """
        Compares the name, the shape and the dtype. Attributes
        `added_dtype_` and `added_shape_` are not compared.

        :param other: @see cl Variable
        :return: boolean
        :raises TypeError: if *other* is not a @see cl Variable
        """
        if not isinstance(other, Variable):
            raise TypeError(
                "Unexpected type %r." % type(other))
        if self.name != other.name:
            return False
        if self.shape_ != other.shape_:
            return False
        if self.dtype_ != other.dtype_:
            return False
        return True

    def make_value_info(self):
        """
        Converts the variable into `onnx.ValueInfoProto`.
        The type is built from `proto_type` and `shape`.

        :return: instance of `onnx.ValueInfoProto`
        """
        value_info = ValueInfoProto()
        value_info.name = self.name
        tensor_type_proto = make_tensor_type_proto(self.proto_type, self.shape)
        value_info.type.CopyFrom(tensor_type_proto)  # pylint: disable=E1101
        return value_info

    @staticmethod
    def from_pb(obj):
        """
        Creates a Variable from a protobuf object.

        :param obj: initializer, tensor
        :return: @see cl Variable
        """
        from ..onnx_tools.onnx2py_helper import from_pb
        name, ty, shape = from_pb(obj)
        return Variable(name, ty, shape=shape)
class NodeResultName:
    """
    Name of one result (output) produced by a node.

    :param node: node it comes from
    :param index: index of the output
    """

    def __init__(self, node, index):
        self.node, self.index = node, index

    def __repr__(self):
        "Usual"
        cls_name = self.__class__.__name__
        return "%s(%r, %r)" % (cls_name, self.node, self.index)

    def get_name(self):
        """
        Returns a name from output_names or a suggestion for a name.

        If the node defines `output_names`, the name of the item at
        position `index` is returned. Otherwise the name is built from
        the first three characters of the lowered `op_type` and the index.
        """
        node = self.node
        if node is None:
            raise RuntimeError(  # pragma: no cover
                "node must not be None.")
        if node.output_names is None:
            prefix = node.op_type.lower()[:3]
            return "out_%s_%d" % (prefix, self.index)
        return node.output_names[self.index].name
class DetectedVariable:
    """
    Wrapper around a @see cl Variable to detect inputs
    and outputs of a graph.

    :param node: node where the variable was detected
    :param var: instance of @see cl Variable
    :param index: index, only used if it is an output
    """

    def __init__(self, node, var, index):
        if not isinstance(var, Variable):
            raise TypeError(  # pragma: no cover
                "Unexpected type %r, it should be a Variable."
                "" % type(var))
        self.node, self.var, self.index = node, var, index

    @property
    def name(self):
        "Returns variable name."
        return self.var.name

    def __repr__(self):
        "usual"
        cls_name = self.__class__.__name__
        # The index is only displayed when it is not negative.
        if self.index >= 0:
            suffix = ", %s" % self.index
        else:
            suffix = ""
        if self.node is not None:
            # The node is identified by its class name and its id.
            return "%s(%s-%d, %r%s)" % (
                cls_name, self.node.__class__.__name__,
                id(self.node), self.var, suffix)
        return "%s(None, %r%s)" % (cls_name, self.var, suffix)
class InputDetectedVariable(DetectedVariable):
    """
    Instance of @see cl DetectedVariable.
    Only for inputs, the index is always -1.
    """

    def __init__(self, node, var):
        super().__init__(node, var, -1)
class OutputDetectedVariable(DetectedVariable):
    """
    Instance of @see cl DetectedVariable.
    Only for outputs. It adds nothing to the base class.
    """