Coverage for mlprodict/npy/onnx_numpy_wrapper.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Wraps :epkg:`numpy` functions into :epkg:`onnx`.
5.. versionadded:: 0.6
6"""
7import warnings
8from .onnx_version import FctVersion
9from .onnx_numpy_annotation import get_args_kwargs
10from .onnx_numpy_compiler import OnnxNumpyCompiler
13class _created_classes:
14 """
15 Class to store all dynamic classes created by wrappers.
16 """
18 def __init__(self):
19 self.stored = {}
21 def append(self, name, cl):
22 """
23 Adds a class into `globals()` to enable pickling on dynamic
24 classes.
25 """
26 if name in self.stored:
27 warnings.warn( # pragma: no cover
28 "Class %r overwritten in\n%r\n---\n%r" % (
29 name, ", ".join(sorted(self.stored)), cl),
30 RuntimeWarning)
31 self.stored[name] = cl
32 globals()[name] = cl
# Module-level registry used by `onnxnumpy` and `onnxnumpy_np` to store the
# wrapper classes they create; registered classes are also injected into
# this module's globals() so they can be found by name.
_created_classes_inst = _created_classes()
class wrapper_onnxnumpy:
    """
    Intermediate wrapper holding a reference to the compiler
    (type: @see cl OnnxNumpyCompiler).

    :param compiled: instance of @see cl OnnxNumpyCompiler

    .. versionadded:: 0.6
    """

    def __init__(self, compiled):
        self.compiled = compiled

    def __call__(self, *args, **kwargs):
        """
        Runs the compiled function on *args* and *kwargs*.

        When the compiled version fails with :class:`TypeError`,
        :class:`RuntimeError` or :class:`ValueError` and at least one
        positional argument is an `OnnxVar`, the python function stored
        in the class attribute `__fct__` (set on the subclasses created
        by `onnxnumpy`) is called instead. Otherwise a
        :class:`RuntimeError` chained to the original error is raised.
        """
        from .onnx_variable import OnnxVar
        try:
            return self.compiled(*args, **kwargs)
        except (TypeError, RuntimeError, ValueError) as e:
            has_onnx_var = any(isinstance(a, OnnxVar) for a in args)
            if not has_onnx_var:
                raise RuntimeError(
                    "Unable to call the compiled version, args is %r. "
                    "kwargs=%r." % ([type(a) for a in args], kwargs)) from e
            return self.__class__.__fct__(  # pylint: disable=E1101
                *args, **kwargs)

    def __getstate__(self):
        """
        Serializes only the compiled object; the function which
        generates the ONNX graph is not needed anymore.
        """
        state = {'compiled': self.compiled}
        return state

    def __setstate__(self, state):
        """
        Restores the compiled object saved by `__getstate__`.
        """
        self.compiled = state['compiled']

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        Keyword arguments are forwarded to the compiler's `to_onnx`;
        they distinguish between multiple graphs when a function
        supports multiple types.

        :return: ONNX graph
        """
        compiled = self.compiled
        return compiled.to_onnx(**kwargs)
def onnxnumpy(op_version=None, runtime=None, signature=None):
    """
    Decorator to declare a function implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by
        @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: decorator turning a function into an instance of a
        subclass of @see cl wrapper_onnxnumpy

    Decorating `foo` with `@onnxnumpy(arg)` is equivalent to
    `onnxnumpy(arg)(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        compiled = OnnxNumpyCompiler(
            fct, op_version=op_version, runtime=runtime,
            signature=signature)
        # Each decorated function gets its own subclass, named after the
        # function, the opset and the runtime, and registered so that it
        # can be found again by name (needed for pickling).
        cls_name = "onnxnumpy_%s_%s_%s" % (
            fct.__name__, str(op_version), runtime)
        namespace = {
            '__doc__': fct.__doc__,
            '__name__': cls_name,
            '__fct__': fct,
        }
        newclass = type(cls_name, (wrapper_onnxnumpy,), namespace)
        _created_classes_inst.append(cls_name, newclass)
        return newclass(compiled)
    return decorator_fct
def onnxnumpy_default(fct):
    """
    Decorator without options to declare a function implemented
    using :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators. It is a shortcut for `onnxnumpy()(fct)`, leaving
    `op_version`, `runtime` and `signature` to None.

    :param fct: function to wrap
    :return: wrapped function, as returned by `onnxnumpy()(fct)`

    .. versionadded:: 0.6
    """
    decorator = onnxnumpy()
    return decorator(fct)
class wrapper_onnxnumpy_np:
    """
    Intermediate wrapper to store a pointer
    on the compiler (type: @see cl OnnxNumpyCompiler)
    supporting multiple signatures. One compiled version
    (instance of a subclass of @see cl wrapper_onnxnumpy) is created
    on demand for every @see cl FctVersion and cached in
    `signed_compiled`.

    .. versionadded:: 0.6
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: must contain `fct` (function to convert) and
            `signature` (may be None); `fctsig` is optional;
            `op_version` and `runtime` must also be present before
            any version is compiled (they are read by `_populate`)
        """
        self.fct = kwargs['fct']
        self.signature = kwargs['signature']
        self.fctsig = kwargs.get('fctsig', None)
        # self.kwargs is used in __call__ as a mapping
        # {optional parameter name: default value}
        self.args, self.kwargs = get_args_kwargs(
            self.fct,
            0 if self.signature is None else self.signature.n_optional)
        # all constructor arguments, reused by _populate
        self.data = kwargs
        # FctVersion -> compiled wrapper, filled by _populate
        self.signed_compiled = {}

    def __getstate__(self):
        """
        Serializes everything but the function which generates
        the ONNX graph, not needed anymore.
        """
        data_copy = {k: v for k, v in self.data.items() if k != 'fct'}
        return dict(signature=self.signature, args=self.args,
                    kwargs=self.kwargs, data=data_copy,
                    signed_compiled=self.signed_compiled)

    def __setstate__(self, state):
        """
        Restores serialized data (every key becomes an attribute).
        """
        for k, v in state.items():
            setattr(self, k, v)

    def __getitem__(self, dtype):
        """
        Returns the instance of @see cl wrapper_onnxnumpy
        mapped to *dtype*, compiling it first if needed.

        :param dtype: @see cl FctVersion describing the input types
            of the function
        :return: instance of @see cl wrapper_onnxnumpy
        :raises TypeError: *dtype* is not a @see cl FctVersion
        """
        if not isinstance(dtype, FctVersion):
            raise TypeError(  # pragma: no cover
                "dtype must be of type 'FctVersion' not %s: %s." % (
                    type(dtype), dtype))
        if dtype not in self.signed_compiled:
            self._populate(dtype)
        return self.signed_compiled[dtype]

    def __call__(self, *args, **kwargs):
        """
        Calls the compiled version matching the inputs, compiling it
        first if needed.

        The version key is built from the `dtype.type` of every
        positional argument (None and objects having a `fit` attribute
        are used as they are) and from the values of the optional
        parameters listed in `self.kwargs`, taken from *kwargs* or from
        their defaults. Keyword arguments not listed in `self.kwargs`
        are ignored. Only *args* are forwarded to the compiled version.

        If an :class:`AttributeError` occurs (for example an argument
        without `dtype`) and one argument is an `OnnxVar`, the original
        function `__fct__` (set on the subclasses created by
        `onnxnumpy_np`) is called with *args* and *kwargs*; otherwise a
        :class:`RuntimeError` is raised.
        """
        from .onnx_variable import OnnxVar
        if len(self.kwargs) == 0:
            others = None
        else:
            others = tuple(kwargs.get(k, self.kwargs[k]) for k in self.kwargs)
        try:
            key = FctVersion(  # pragma: no cover
                tuple(a if (a is None or hasattr(a, 'fit'))
                      else a.dtype.type for a in args),
                others)
            return self[key](*args)
        except AttributeError as e:
            if any(isinstance(a, OnnxVar) for a in args):
                return self.__class__.__fct__(  # pylint: disable=E1101
                    *args, **kwargs)
            raise RuntimeError(
                "Unable to call the compiled version, args is %r. "
                "kwargs=%r." % ([type(a) for a in args], kwargs)) from e

    def _populate(self, version):
        """
        Compiles function *fct* for *version* and stores the result,
        an instance of a new subclass of @see cl wrapper_onnxnumpy,
        in `signed_compiled[version]`.

        :param version: @see cl FctVersion
        """
        compiled = OnnxNumpyCompiler(
            fct=self.data["fct"], op_version=self.data["op_version"],
            runtime=self.data["runtime"], signature=self.data["signature"],
            version=version, fctsig=self.data.get('fctsig', None))
        name = "onnxnumpy_np_%s_%s_%s_%s" % (
            self.data["fct"].__name__, str(self.data["op_version"]),
            self.data["runtime"], version.as_string())
        # NOTE(review): unlike the classes created by the decorators, this
        # dynamic class is not registered with _created_classes_inst, so
        # pickling `signed_compiled` may fail to find it by name — confirm.
        newclass = type(
            name, (wrapper_onnxnumpy,),
            {'__doc__': self.data["fct"].__doc__, '__name__': name})

        self.signed_compiled[version] = newclass(compiled)

    def _validate_onnx_data(self, X):
        """
        Returns *X* unchanged.
        """
        return X

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for one of the compiled versions.

        Without keyword arguments, the only compiled version is used
        (a :class:`ValueError` is raised if there are several).
        Otherwise the only accepted keyword is `key`, which may be:

        * a key of `signed_compiled` (exact match),
        * a value equal to the `args` attribute of exactly one key,
        * a single value matching exactly one key whose `args` all
          equal that value.

        :param kwargs: empty or `key=...`
        :return: ONNX graph
        :raises RuntimeError: nothing was compiled yet
        :raises ValueError: unexpected keyword arguments, or the key
            is ambiguous or unknown
        """
        if len(self.signed_compiled) == 0:
            raise RuntimeError(  # pragma: no cover
                "No ONNX graph was compiled.")
        if len(kwargs) == 0 and len(self.signed_compiled) == 1:
            # We take the only one.
            key = list(self.signed_compiled)[0]
            cpl = self.signed_compiled[key]
            return cpl.to_onnx()
        if len(kwargs) == 0:
            raise ValueError(
                "There are multiple compiled ONNX graphs associated "
                "with keys %r (add key=...)." % list(self.signed_compiled))
        if list(kwargs) != ['key']:
            raise ValueError(
                "kwargs should contain one parameter key=... but "
                "it is %r." % kwargs)
        key = kwargs['key']
        if key in self.signed_compiled:
            return self.signed_compiled[key].compiled.onnx_
        # Partial match: either the whole tuple of input types, or one
        # type repeated for every input.
        found = [(k, v) for k, v in self.signed_compiled.items()
                 if k.args == key or k.args == (key, ) * len(k.args)]
        if len(found) == 1:
            return found[0][1].compiled.onnx_
        raise ValueError(
            "Unable to find signature with key=%r among %r found=%r." % (
                key, list(self.signed_compiled), found))
def onnxnumpy_np(op_version=None, runtime=None, signature=None):
    """
    Decorator to declare a function implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators. Unlike `onnxnumpy`, the returned wrapper
    (@see cl wrapper_onnxnumpy_np) compiles one version per
    input signature.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: decorator turning a function into an instance of a
        subclass of @see cl wrapper_onnxnumpy_np

    Decorating `foo` with `@onnxnumpy_np(arg)` is equivalent to
    `onnxnumpy_np(arg)(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        cls_name = "onnxnumpy_nb_%s_%s_%s" % (
            fct.__name__, str(op_version), runtime)
        # The subclass carries the original function in `__fct__` and is
        # registered so that it can be found again by name (pickling).
        namespace = dict(
            __doc__=fct.__doc__,
            __name__=cls_name,
            __getstate__=wrapper_onnxnumpy_np.__getstate__,
            __setstate__=wrapper_onnxnumpy_np.__setstate__,
            __fct__=fct)
        newclass = type(cls_name, (wrapper_onnxnumpy_np,), namespace)
        _created_classes_inst.append(cls_name, newclass)
        return newclass(fct=fct, op_version=op_version,
                        runtime=runtime, signature=signature)

    return decorator_fct