Coverage for mlprodict/npy/onnx_numpy_annotation.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief :epkg:`numpy` annotations.
5.. versionadded:: 0.6
6"""
7import inspect
8from collections import OrderedDict
9from typing import TypeVar, Generic
10import numpy
11from .onnx_version import FctVersion
# Scalar aliases: use the numpy scalar type when this numpy build exposes it,
# otherwise fall back to the matching Python builtin.
numpy_bool = getattr(numpy, "bool_", bool)
numpy_str = getattr(numpy, "str_", str)

# Type variables parametrising the generic annotation class ``NDArray``.
Shape = TypeVar("Shape")
DType = TypeVar("DType")


# Element types covered by the ``"all"`` shortcut: floats first, then signed
# and unsigned integers.
all_dtypes = (
    numpy.float32, numpy.float64,
    numpy.int32, numpy.int64,
    numpy.uint32, numpy.uint64,
)
def get_args_kwargs(fct, n_optional):
    """
    Extracts arguments and optional parameters of a function.

    :param fct: function
    :param n_optional: number of arguments to consider as
        optional arguments and not parameters, the first
        *n_optional* parameters with a default value are returned
        as arguments instead of parameters
    :return: arguments, OrderedDict

    Parameters without a default value are always returned as arguments.
    Parameter ``op_version`` is never returned among the parameters.
    If the first argument is ``self``, it is removed from the arguments
    and the key ``'op_'`` is added to the parameters with value ``None``.
    """
    # ``Parameter.empty`` is a sentinel: it is compared by identity so that
    # a default value overloading ``==`` (a numpy array for example)
    # is never compared element-wise.
    empty = inspect.Parameter.empty
    params = inspect.signature(fct).parameters

    args = []
    items = []
    remaining = n_optional
    for name, p in params.items():
        if p.default is empty:
            args.append(name)
        elif remaining > 0:
            # One of the first defaulted parameters promoted to an argument.
            args.append(name)
            remaining -= 1
        else:
            items.append((name, p))

    kwargs = OrderedDict((name, p.default) for name, p in items
                         if name != 'op_version')
    # A function without any required argument produces an empty list,
    # ``args[0]`` must not be evaluated in that case.
    if args and args[0] == 'self':
        args = args[1:]
        kwargs['op_'] = None
    return args, kwargs
class NDArray(numpy.ndarray, Generic[Shape, DType]):
    """
    Type used to annotate ONNX numpy functions.
    Subscripting the class (``NDArray[...]``) does not build a
    generic alias, it returns an instance of :class:`NDArray.ShapeType`
    holding the subscript.

    .. versionadded:: 0.6
    """
    class ShapeType:
        "Keeps the subscript given to ``NDArray[...]`` in ``__args__``."

        def __init__(self, params):
            self.__args__ = params

    def __class_getitem__(cls, params):  # pylint: disable=W0221,W0237
        "Wraps the subscript into a tuple and returns a ``ShapeType``."
        as_tuple = params if isinstance(params, tuple) else (params,)
        return NDArray.ShapeType(as_tuple)
class _NDArrayAlias:
    """
    Ancestor to custom signature.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    *dtypes*, *dtypes_out* by default are a tuple of tuple:

    * first dimension: type of every input
    * second dimension: list of types for one input

    A string ``'<in>_<out>'`` given as *dtypes* is split on ``'_'``
    into input and output dtypes, the value of *dtypes_out* is then replaced.
    When *dtypes_out* is None, the output takes the types of the first input,
    when it is an integer, it takes the types of the input at that position.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None,
                 nvars=False):
        "constructor"
        if dtypes is None:
            raise ValueError("dtypes cannot be None.")  # pragma: no cover
        if isinstance(dtypes, tuple) and len(dtypes) == 0:
            raise TypeError("dtypes must not be empty.")  # pragma: no cover
        if isinstance(dtypes, tuple) and not isinstance(dtypes[0], tuple):
            # Flat tuple: every non-string element becomes a 1-tuple,
            # strings are shortcuts expanded by _process_type.
            dtypes = tuple(t if isinstance(t, str) else (t,) for t in dtypes)
        if isinstance(dtypes, str) and '_' in dtypes:
            dtypes, dtypes_out = dtypes.split('_')
        if not isinstance(dtypes, (tuple, list)):
            dtypes = (dtypes, )

        # alias name -> (shortcut, index), filled by _process_type
        self.mapped_types = {}
        self.dtypes = _NDArrayAlias._process_type(
            dtypes, self.mapped_types, 0)
        if dtypes_out is None:
            self.dtypes_out = (self.dtypes[0], )
        elif isinstance(dtypes_out, int):
            self.dtypes_out = (self.dtypes[dtypes_out], )
        else:
            if not isinstance(dtypes_out, (tuple, list)):
                dtypes_out = (dtypes_out, )
            self.dtypes_out = _NDArrayAlias._process_type(
                dtypes_out, self.mapped_types, 0)
        self.n_optional = 0 if n_optional is None else n_optional
        self.n_variables = nvars

        # Both attributes must be a non empty tuple of non empty tuples
        # whose elements are not tuples themselves.
        if not isinstance(self.dtypes, tuple):
            raise TypeError(  # pragma: no cover
                "self.dtypes must be a tuple not {}.".format(self.dtypes))
        if (len(self.dtypes) == 0 or
                not isinstance(self.dtypes[0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes: {}.".format(self.dtypes))
        if (len(self.dtypes[0]) == 0 or
                isinstance(self.dtypes[0][0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes: {}.".format(self.dtypes))

        if not isinstance(self.dtypes_out, tuple):
            raise TypeError(  # pragma: no cover
                "self.dtypes_out must be a tuple not {}.".format(self.dtypes_out))
        if (len(self.dtypes_out) == 0 or
                not isinstance(self.dtypes_out[0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes_out={}, "
                "self.dtypes={}.".format(self.dtypes_out, self.dtypes))
        if (len(self.dtypes_out[0]) == 0 or
                isinstance(self.dtypes_out[0][0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes_out: {}.".format(self.dtypes_out))

        if self.n_variables and self.n_optional > 0:
            raise RuntimeError(  # pragma: no cover
                "n_variables and n_optional cannot be positive at "
                "the same type.")

    @staticmethod
    def _process_type(dtypes, mapped_types, index):
        """
        Nicknames such as `floats`, `int`, `ints`, `all`
        can be used to describe multiple inputs for
        a signature. This function interprets that.

        :param dtypes: a shortcut (string), optionally prefixed by
            an alias (``'T:all'``), a tuple or a list of such values,
            or one of the types in ``all_dtypes``
        :param mapped_types: dictionary updated in place,
            maps an alias to ``(shortcut, index)``
        :param index: position stored with a newly declared alias,
            elements of a tuple or a list receive *index* plus
            their position
        :return: tuple of types for a string, tuple of results for
            a tuple or a list, the type itself otherwise

        An alias may be declared more than once as long as it is always
        bound to the same shortcut, the first declaration is kept.
        A different shortcut raises an exception.

        .. runpython::
            :showcode:

            from mlprodict.npy.onnx_numpy_annotation import _NDArrayAlias
            for name in ['all', 'int', 'ints', 'floats', 'T']:
                print(name, _NDArrayAlias._process_type(name, {'T': 0}, 0))
        """
        if isinstance(dtypes, str):
            if ":" in dtypes:
                name, dtypes = dtypes.split(':')
                if name in mapped_types:
                    known = mapped_types[name]
                    # Registered values are (shortcut, index), the shortcut
                    # is the only part to compare with. Any other kind of
                    # value is compared as is.
                    if isinstance(known, tuple):
                        known = known[0]
                    if dtypes != known:
                        raise RuntimeError(  # pragma: no cover
                            "Type name mismatch for '%s:%s' in %r." % (
                                name, dtypes, list(sorted(mapped_types))))
                else:
                    mapped_types[name] = (dtypes, index)
            if dtypes == "all":
                dtypes = all_dtypes
            elif dtypes in ("int", "int64"):
                dtypes = (numpy.int64, )
            elif dtypes == "bool":
                dtypes = (numpy_bool, )
            elif dtypes == "floats":
                dtypes = (numpy.float32, numpy.float64)
            elif dtypes == "ints":
                dtypes = (numpy.int32, numpy.int64)
            elif dtypes == "float32":
                dtypes = (numpy.float32, )
            elif dtypes == "float64":
                dtypes = (numpy.float64, )
            elif dtypes not in mapped_types:
                raise ValueError(  # pragma: no cover
                    "Unexpected shortcut for dtype %r." % dtypes)
            elif not isinstance(dtypes, tuple):
                # Known alias: kept by name, resolved in _get_output_types.
                dtypes = (dtypes, )
            return dtypes

        if isinstance(dtypes, (tuple, list)):
            insig = [_NDArrayAlias._process_type(dt, mapped_types, index + d)
                     for d, dt in enumerate(dtypes)]
            return tuple(insig)

        if dtypes in all_dtypes:
            return dtypes

        raise NotImplementedError(  # pragma: no cover
            "Unexpected input dtype %r." % dtypes)

    def __repr__(self):
        "usual"
        return "%s(%r, %r, %r)" % (
            self.__class__.__name__, self.dtypes, self.dtypes_out,
            self.n_optional)

    def _get_output_types(self, key):
        """
        Tries to infer output types.

        :param key: sequence of input types
        :return: tuple with one type per output

        For every output, the rules are tried in this order:
        a single concrete type is returned as is, a single alias
        returns the input type at the index registered for this alias,
        otherwise the type of the first input is returned
        if the output allows it.
        """
        res = []
        for i, o in enumerate(self.dtypes_out):
            if not isinstance(o, tuple):
                raise TypeError(  # pragma: no cover
                    "All outputs must be tuple, output %d is %r."
                    "" % (i, o))
            if (len(o) == 1 and (o[0] in all_dtypes or
                                 o[0] in (bool, numpy_bool, str, numpy_str))):
                res.append(o[0])
            elif len(o) == 1 and o[0] in self.mapped_types:
                info = self.mapped_types[o[0]]
                res.append(key[info[1]])
            elif key[0] in o:
                res.append(key[0])
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to guess output type for output %d, "
                    "input types are %r, expected output is %r."
                    "" % (i, key, o))
        return tuple(res)

    def get_inputs_outputs(self, args, kwargs, version):
        """
        Returns the list of inputs, outputs.

        :param args: list of arguments
        :param kwargs: list of optional arguments
        :param version: required version
        :return: *tuple(inputs, kwargs, outputs, optional, n_variables)*,
            inputs and outputs are lists of pairs *(name, type)*,
            kwargs is the parameter *kwargs* returned unchanged,
            *optional* is the number of optional arguments,
            *n_variables* is the value of parameter *nvars* given
            to the constructor
        """
        if not isinstance(version, FctVersion):
            raise TypeError("Version must be of type 'FctVersion' not "
                            "%s, version=%s." % (type(version), version))
        if args == ['args', 'kwargs']:
            raise RuntimeError(  # pragma: no cover
                "Issue with signature %r." % args)
        for k, v in kwargs.items():
            if isinstance(v, type):
                raise RuntimeError(  # pragma: no cover
                    "Default value for argument %r must not be of type %r"
                    "." % (k, v))
        if (not self.n_variables and
                len(args) > len(self.dtypes)):
            raise RuntimeError(
                "Unexpected number of inputs version=%s.\n"
                "Given: args=%s dtypes=%s." % (
                    version, args, self.dtypes))

        def _possible_names():
            yield 'y'
            yield 'z'  # pragma: no cover
            yield 'o'  # pragma: no cover
            for i in range(0, 10000):  # pragma: no cover
                yield 'o%d' % i

        # Only used by the consistency check below, the method returns
        # the original kwargs.
        new_kwargs = OrderedDict(zip(kwargs, version.kwargs or tuple()))
        if self.n_variables:
            # undefined number of inputs
            optional = 0
        else:
            optional = len(self.dtypes) - len(version.args)
            if optional > self.n_optional:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of optional parameters %d, at most "
                    "%d are expected, version=%s, args=%s, dtypes=%s." % (
                        optional, self.n_optional, version, args, self.dtypes))
            optional = self.n_optional - optional

        onnx_types = list(version.args)
        inputs = list(zip(args[:len(version.args)], onnx_types))
        if self.n_variables and len(inputs) < len(version.args):
            # Complete the list of inputs
            last_name = inputs[-1][0]
            while len(inputs) < len(onnx_types):
                inputs.append(('%s%d' % (last_name, len(inputs)),
                               onnx_types[len(inputs)]))

        key_out = self._get_output_types(version.args)
        onnx_types_out = key_out

        # Every output receives the first candidate name which is neither
        # an input name nor an output name already given.
        names_out = []
        names_in = set(inp[0] for inp in inputs)
        for _ in key_out:
            for name in _possible_names():
                if name not in names_in and name not in names_out:
                    name_out = name
                    break
            names_out.append(name_out)
            names_in.add(name_out)

        outputs = list(zip(names_out, onnx_types_out))
        if optional < 0:
            raise RuntimeError(  # pragma: no cover
                "optional cannot be negative %r (self.n_optional=%r, "
                "len(self.dtypes)=%r, len(inputs)=%r) "
                "names_in=%r, names_out=%r." % (
                    optional, self.n_optional, len(self.dtypes),
                    len(inputs), names_in, names_out))

        if (not self.n_variables and
                len(inputs) + len(new_kwargs) > len(version)):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of inputs and arguments for version=%s.\n"
                "Given: args=%s kwargs=%s.\n"
                "Returned: inputs=%s new_kwargs=%s.\n" % (
                    version, args, kwargs, inputs, new_kwargs))
        if not self.n_variables and len(inputs) > len(self.dtypes):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of inputs for version=%s.\n"
                "Given: args=%s.\n"
                "Expected: dtypes=%s\n"
                "Returned: inputs=%s.\n" % (
                    version, args, self.dtypes, inputs))

        return inputs, kwargs, outputs, optional, self.n_variables

    def shape_calculator(self, dims):
        """
        Returns expected dimensions given the input dimensions.

        :param dims: input dimensions
        :return: None if *dims* is empty, otherwise a list of the same
            length, the first dimension is kept, the others are None
        """
        if len(dims) == 0:
            return None
        return [dims[0]] + [None] * (len(dims) - 1)
class NDArrayType(_NDArrayAlias):
    """
    Shortcut to simplify signature description,
    every parameter is forwarded to the parent class.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        forwarded = dict(dtypes=dtypes, dtypes_out=dtypes_out,
                         n_optional=n_optional, nvars=nvars)
        _NDArrayAlias.__init__(self, **forwarded)
class NDArrayTypeSameShape(NDArrayType):
    """
    Shortcut to simplify signature description,
    every parameter is forwarded to :class:`NDArrayType`.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        forwarded = dict(dtypes=dtypes, dtypes_out=dtypes_out,
                         n_optional=n_optional, nvars=nvars)
        NDArrayType.__init__(self, **forwarded)
class NDArraySameType(NDArrayType):
    """
    Shortcut to simplify signature description,
    the signature is described by a single type specification.

    :param dtypes: input dtypes, it cannot be None, a tuple
        or a string containing ``'_'``

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        if dtypes is None:
            raise ValueError("dtypes cannot be None.")  # pragma: no cover
        has_separator = isinstance(dtypes, str) and "_" in dtypes
        if has_separator:
            raise ValueError(  # pragma: no cover
                "dtypes cannot include '_' meaning two different types.")
        if isinstance(dtypes, tuple):
            raise ValueError(  # pragma: no cover
                "dtypes must be a single type.")
        # The specification is wrapped into a 1-tuple before being forwarded.
        wrapped = (dtypes, )
        NDArrayType.__init__(self, dtypes=wrapped)

    def __repr__(self):
        "usual"
        class_name = self.__class__.__name__
        return "%s(%r)" % (class_name, self.dtypes)
class NDArraySameTypeSameShape(NDArraySameType):
    """
    Shortcut to simplify signature description,
    the parameter is forwarded to :class:`NDArraySameType`.

    :param dtypes: input dtypes

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        forwarded = dict(dtypes=dtypes)
        NDArraySameType.__init__(self, **forwarded)