Coverage for mlprodict/npy/onnx_numpy_wrapper.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

109 statements  

1""" 

2@file 

3@brief Wraps :epkg:`numpy` functions into :epkg:`onnx`. 

4 

5.. versionadded:: 0.6 

6""" 

7import warnings 

8from .onnx_version import FctVersion 

9from .onnx_numpy_annotation import get_args_kwargs 

10from .onnx_numpy_compiler import OnnxNumpyCompiler 

11 

12 

class _created_classes:
    """
    Registry of the classes created on the fly by the decorators
    of this module (@see fn onnxnumpy, @see fn onnxnumpy_np).

    Every registered class is also published in the module namespace,
    which is where :mod:`pickle` looks a class up by name when it
    restores an instance.
    """

    def __init__(self):
        # {class name: class}
        self.stored = {}

    def append(self, name, cl):
        """
        Registers class *cl* under *name* and adds it into `globals()`
        to enable pickling on dynamic classes.

        :param name: class name, used as the registry key and as the
            attribute name in the module namespace
        :param cl: the class to register

        A :class:`RuntimeWarning` is emitted when *name* is already
        registered; the previous class is then replaced.
        """
        registry = self.stored
        if name in registry:
            known_names = ", ".join(sorted(registry))
            warnings.warn(  # pragma: no cover
                "Class %r overwritten in\n%r\n---\n%r" % (
                    name, known_names, cl),
                RuntimeWarning)
        registry[name] = cl
        globals()[name] = cl


# Single registry shared by every decorator of this module.
_created_classes_inst = _created_classes()

36 

37 

class wrapper_onnxnumpy:
    """
    Intermediate wrapper to store a pointer
    on the compiler (type: @see cl OnnxNumpyCompiler).
    Calling the wrapper calls the compiled function.

    :param compiled: instance of @see cl OnnxNumpyCompiler

    .. versionadded:: 0.6
    """

    def __init__(self, compiled):
        self.compiled = compiled

    def __call__(self, *args, **kwargs):
        """
        Calls the compiled function with arguments `args` and `kwargs`.

        When the compiled function raises :class:`TypeError`,
        :class:`RuntimeError` or :class:`ValueError` and at least one
        positional argument is an `OnnxVar`, the call falls back to the
        python function stored in class attribute `__fct__`.
        Otherwise the error is re-raised as a :class:`RuntimeError`.
        """
        from .onnx_variable import OnnxVar
        try:
            return self.compiled(*args, **kwargs)
        except (TypeError, RuntimeError, ValueError) as exc:
            # Symbolic inputs: the function is presumably being traced
            # inside another graph, so the python code is run instead.
            if not any(isinstance(arg, OnnxVar) for arg in args):
                raise RuntimeError(
                    "Unable to call the compiled version, args is %r. "
                    "kwargs=%r." % ([type(a) for a in args], kwargs)) from exc
            return self.__class__.__fct__(  # pylint: disable=E1101
                *args, **kwargs)

    def __getstate__(self):
        """
        Serializes everything but the function which generates
        the ONNX graph, not needed anymore.
        """
        return {'compiled': self.compiled}

    def __setstate__(self, state):
        """
        Restores the state produced by `__getstate__`.
        """
        self.compiled = state['compiled']

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.
        The arguments are forwarded unchanged to the compiler.

        :return: ONNX graph
        """
        compiler = self.compiled
        return compiler.to_onnx(**kwargs)

89 

90 

def onnxnumpy(op_version=None, runtime=None, signature=None):
    """
    Decorator to declare a function implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by
        @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: a decorator which compiles a function and returns an
        instance of a dynamically created subclass of
        @see cl wrapper_onnxnumpy

    Usage: `onnxnumpy(op_version=...)(foo)`;
    @see fn onnxnumpy_default is the shortcut for `onnxnumpy()(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        compiled = OnnxNumpyCompiler(
            fct, op_version=op_version, runtime=runtime,
            signature=signature)
        class_name = "onnxnumpy_%s_%s_%s" % (
            fct.__name__, str(op_version), runtime)
        # '__fct__' keeps the python function for the OnnxVar fallback
        # in wrapper_onnxnumpy.__call__.
        members = {'__doc__': fct.__doc__, '__name__': class_name,
                   '__fct__': fct}
        wrapper_class = type(class_name, (wrapper_onnxnumpy,), members)
        # Published in the module namespace so instances can be pickled.
        _created_classes_inst.append(class_name, wrapper_class)
        return wrapper_class(compiled)

    return decorator_fct

118 

119 

def onnxnumpy_default(fct):
    """
    Decorator without options to declare a function implemented
    using :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators. It uses the default values of every option
    of @see fn onnxnumpy.

    :param fct: function to wrap
    :return: same as `onnxnumpy()(fct)`

    .. versionadded:: 0.6
    """
    default_decorator = onnxnumpy()
    return default_decorator(fct)

131 

132 

class wrapper_onnxnumpy_np:
    """
    Intermediate wrapper to store a pointer
    on the compiler (type: @see cl OnnxNumpyCompiler)
    supporting multiple signatures.

    One compiled function (instance of a subclass of
    @see cl wrapper_onnxnumpy) is created lazily for every
    @see cl FctVersion the wrapper is called with and cached
    in attribute *signed_compiled*.

    .. versionadded:: 0.6
    """

    def __init__(self, **kwargs):
        self.fct = kwargs['fct']
        self.signature = kwargs['signature']
        self.fctsig = kwargs.get('fctsig', None)
        n_optional = (
            0 if self.signature is None else self.signature.n_optional)
        self.args, self.kwargs = get_args_kwargs(self.fct, n_optional)
        # All options given to the decorator, reused by _populate.
        self.data = kwargs
        # Cache {FctVersion: compiled wrapper}.
        self.signed_compiled = {}

    def __getstate__(self):
        """
        Serializes everything but the function which generates
        the ONNX graph, not needed anymore.
        """
        data_copy = {k: v for k, v in self.data.items() if k != 'fct'}
        return dict(signature=self.signature, args=self.args,
                    kwargs=self.kwargs, data=data_copy,
                    signed_compiled=self.signed_compiled)

    def __setstate__(self, state):
        """
        Restores serialized data.
        """
        for k, v in state.items():
            setattr(self, k, v)

    def __getitem__(self, dtype):
        """
        Returns the instance of @see cl wrapper_onnxnumpy
        mapped to *dtype*, compiling it first if needed.

        :param dtype: numpy dtype corresponding to the input dtype
            of the function, given as a @see cl FctVersion
        :return: instance of @see cl wrapper_onnxnumpy
        :raises TypeError: if *dtype* is not a @see cl FctVersion
        """
        if not isinstance(dtype, FctVersion):
            raise TypeError(  # pragma: no cover
                "dtype must be of type 'FctVersion' not %s: %s." % (
                    type(dtype), dtype))
        if dtype not in self.signed_compiled:
            self._populate(dtype)
        return self.signed_compiled[dtype]

    def __call__(self, *args, **kwargs):
        """
        Calls the compiled function matching the arguments.

        The version key is built from the dtype of every positional
        argument (`None` and objects with a `fit` method are kept
        as-is) and from the values of the optional keyword arguments
        (falling back to their defaults). If the key cannot be built
        or the call raises :class:`AttributeError` and one argument
        is an `OnnxVar`, the python function `__fct__` is called
        instead; otherwise a :class:`RuntimeError` is raised.
        """
        from .onnx_variable import OnnxVar
        if len(self.kwargs) == 0:
            others = None
        else:
            others = tuple(kwargs.get(k, self.kwargs[k]) for k in self.kwargs)
        try:
            key = FctVersion(  # pragma: no cover
                tuple(a if (a is None or hasattr(a, 'fit'))
                      else a.dtype.type for a in args),
                others)
            # Keyword values are already part of the key (others),
            # only positional arguments are forwarded.
            return self[key](*args)
        except AttributeError as e:
            if any(isinstance(a, OnnxVar) for a in args):
                return self.__class__.__fct__(  # pylint: disable=E1101
                    *args, **kwargs)
            raise RuntimeError(
                "Unable to call the compiled version, args is %r. "
                "kwargs=%r." % ([type(a) for a in args], kwargs)) from e

    def _populate(self, version):
        """
        Creates the appropriate runtime for function *fct* and stores
        it in *signed_compiled* under key *version*.

        The generated class is registered in the module namespace,
        like the classes created by the decorators of this module, so
        that the instances kept in *signed_compiled* (part of the state
        returned by `__getstate__`) can be pickled.

        :param version: instance of @see cl FctVersion
        """
        fct = self.data["fct"]
        op_version = self.data["op_version"]
        runtime = self.data["runtime"]
        compiled = OnnxNumpyCompiler(
            fct=fct, op_version=op_version,
            runtime=runtime, signature=self.data["signature"],
            version=version, fctsig=self.data.get('fctsig', None))
        name = "onnxnumpy_np_%s_%s_%s_%s" % (
            fct.__name__, str(op_version), runtime, version.as_string())
        newclass = type(
            name, (wrapper_onnxnumpy,),
            {'__doc__': fct.__doc__, '__name__': name})
        # Without this, pickle cannot find the class by name.
        _created_classes_inst.append(name, newclass)
        self.signed_compiled[version] = newclass(compiled)

    def _validate_onnx_data(self, X):
        return X

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.

        Without arguments, the only compiled graph is returned.
        Otherwise `key=...` selects the graph: either an exact
        @see cl FctVersion, the tuple of input types (`args`) of one
        version, or a single type shared by every input of one version.

        :return: ONNX graph
        :raises RuntimeError: if nothing was compiled yet
        :raises ValueError: if the requested graph is ambiguous or
            cannot be found, or if *kwargs* is not exactly `key=...`
        """
        if not self.signed_compiled:
            raise RuntimeError(  # pragma: no cover
                "No ONNX graph was compiled.")
        if len(kwargs) == 0 and len(self.signed_compiled) == 1:
            # We take the only one.
            key = list(self.signed_compiled)[0]
            cpl = self.signed_compiled[key]
            return cpl.to_onnx()
        if len(kwargs) == 0:
            raise ValueError(
                "There are multiple compiled ONNX graphs associated "
                "with keys %r (add key=...)." % list(self.signed_compiled))
        if list(kwargs) != ['key']:
            raise ValueError(
                "kwargs should contain one parameter key=... but "
                "it is %r." % kwargs)
        key = kwargs['key']
        if key in self.signed_compiled:
            return self.signed_compiled[key].compiled.onnx_
        found = [(k, v) for k, v in self.signed_compiled.items()
                 if k.args == key or k.args == (key, ) * len(k.args)]
        if len(found) == 1:
            return found[0][1].compiled.onnx_
        raise ValueError(
            "Unable to find signature with key=%r among %r found=%r." % (
                key, list(self.signed_compiled), found))

274 

275 

def onnxnumpy_np(op_version=None, runtime=None, signature=None):
    """
    Decorator to declare a function implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators. Unlike @see fn onnxnumpy, the function is compiled
    lazily, once per signature it is called with
    (see @see cl wrapper_onnxnumpy_np).

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: a decorator turning a function into an instance of a
        dynamically created subclass of @see cl wrapper_onnxnumpy_np

    Usage: `onnxnumpy_np(op_version=...)(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        class_name = "onnxnumpy_nb_%s_%s_%s" % (
            fct.__name__, str(op_version), runtime)
        members = {
            '__doc__': fct.__doc__,
            '__name__': class_name,
            '__getstate__': wrapper_onnxnumpy_np.__getstate__,
            '__setstate__': wrapper_onnxnumpy_np.__setstate__,
            # python function used as a fallback for OnnxVar inputs
            '__fct__': fct,
        }
        wrapper_class = type(
            class_name, (wrapper_onnxnumpy_np,), members)
        # Published in the module namespace so instances can be pickled.
        _created_classes_inst.append(class_name, wrapper_class)
        return wrapper_class(
            fct=fct, op_version=op_version, runtime=runtime,
            signature=signature)

    return decorator_fct