Coverage for mlprodict/npy/onnx_numpy_compiler.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

204 statements  

1""" 

2@file 

3@brief Implements :epkg:`numpy` functions with onnx and a runtime. 

4 

5.. versionadded:: 0.6 

6""" 

7import inspect 

8import logging 

9from typing import Any 

10import numpy 

11from ..onnx_tools.optim._main_onnx_optim import onnx_optimisations 

12from .onnx_version import FctVersion 

13from .onnx_numpy_annotation import get_args_kwargs 

14from .xop_variable import Variable 

15from .xop import OnnxOperator, OnnxOperatorTuple 

16 

17 

# Module logger; it logs on the 'xop' channel rather than on __name__.
logger = logging.getLogger(name='xop')

19 

20 

class OnnxNumpyFunction:
    """
    Class wrapping a function build with
    @see cl OnnxNumpyCompiler. Subclasses implement ``__call__``
    to execute the runtime *rt*.

    :param compiler: the compiler which built the function
    :param rt: runtime instance used to execute the graph
    :param inputs: list of @see cl Variable, the graph inputs
    :param outputs: list of @see cl Variable, the graph outputs
    :param n_optional: number of inputs the caller may omit
    :param n_variables: when strictly positive, the number of
        positional arguments is not checked by `_check_`

    .. versionadded:: 0.6
    """

    def __init__(self, compiler, rt, inputs, outputs,
                 n_optional, n_variables):
        if not all(isinstance(n, Variable) for n in inputs):
            raise TypeError(  # pragma: no cover
                "All inputs must be of type Variable: %r." % (inputs, ))
        if not all(isinstance(n, Variable) for n in outputs):
            raise TypeError(  # pragma: no cover
                "All outputs must be of type Variable: %r." % (outputs, ))
        self.compiler = compiler
        self.inputs = inputs
        self.outputs = outputs
        self.rt = rt
        self.n_optional = n_optional
        self.n_variables = n_variables
        if n_optional < 0:
            raise RuntimeError(  # pragma: no cover
                "Wrong configuration, n_optional %r must be >= 0."
                "" % n_optional)
        if n_optional >= len(inputs):
            # the message used to say '>=' which contradicts the check
            raise RuntimeError(  # pragma: no cover
                "Wrong configuration, n_optional %r must be < %r, "
                "the number of inputs." % (n_optional, len(inputs)))

    def _check_(self, *args, **kwargs):
        """
        Checks the number of positional arguments is within
        ``[len(inputs) - n_optional, len(inputs)]``.
        The check is skipped if `n_variables` is strictly positive.

        :param args: positional arguments given to the function
        :param kwargs: keyword arguments, only used in the error message
        :raises RuntimeError: if the number of arguments is out of range
        """
        if self.n_variables > 0:
            return
        n_max = len(self.inputs)
        n_min = n_max - self.n_optional
        if not n_min <= len(args) <= n_max:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs %d. It should be in "
                "[%r, %r] len(args)=%d n_optional=%d n_variables=%d"
                "\nargs=%s\nkwargs=%s\ninputs=%s" % (
                    len(args), n_min, n_max,
                    len(args), self.n_optional, self.n_variables,
                    args, kwargs, self.inputs))

64 

65 

class OnnxNumpyFunctionOnnxInference(OnnxNumpyFunction):
    """
    Specialization of @see cl OnnxNumpyFunction which executes
    the graph with an instance of @see cl OnnxInference.

    .. versionadded:: 0.6
    """

    def __call__(self, *args, **kwargs):
        """
        Runs the graph. Positional arguments are bound to the inputs
        in order, keyword arguments are forwarded to ``rt.run``.

        :return: tuple of results ordered like `self.outputs`
        :raises RuntimeError: if the runtime returns an unexpected
            number of outputs
        """
        self._check_(*args, **kwargs)
        feeds = {}
        for var, value in zip(self.inputs, args):
            feeds[var.name] = value
        results = self.rt.run(feeds, **kwargs)
        n_expected = len(self.outputs)
        if len(results) != n_expected:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %d instead of %d." % (
                    len(results), n_expected))
        return tuple(results[var.name] for var in self.outputs)

83 

84 

class OnnxNumpyFunctionInferenceSession(OnnxNumpyFunction):
    """
    Specialization of @see cl OnnxNumpyFunction which executes
    the graph with an `InferenceSession` from :epkg:`onnxruntime`.

    .. versionadded:: 0.6
    """

    def __call__(self, *args, **kwargs):
        """
        Runs the graph. Positional arguments are bound to the inputs
        in order, keyword arguments are not supported.

        :return: tuple of results in the order returned by the runtime
        :raises RuntimeError: if *kwargs* is not empty or if the runtime
            returns an unexpected number of outputs
        """
        self._check_(*args, **kwargs)
        if kwargs:
            raise RuntimeError(  # pragma: no cover
                "kwargs is not used but it is not empty: %r." % kwargs)
        feeds = dict(
            (var.name, value) for var, value in zip(self.inputs, args))
        results = self.rt.run(None, feeds)
        if len(results) == len(self.outputs):
            return tuple(results)
        raise RuntimeError(  # pragma: no cover
            "Unexpected number of outputs %d instead of %d." % (
                len(results), len(self.outputs)))

106 

107 

class OnnxNumpyCompiler:
    """
    Implements a class which runs onnx graph.

    :param fct: a function with annotations which returns an ONNX graph,
        it can also be an ONNX graph.
    :param op_version: :epkg:`ONNX` opset to use, None
        for the latest one
    :param runtime: runtime to choose to execute the onnx graph,
        `python`, `onnxruntime`, `onnxruntime1`
    :param signature: used when the function is not annotated
    :param version: the same function can be instantiated with
        different type, this parameter is None or a numpy type
        if the signature allows multiple types, it must be an instance
        of type @see cl FctVersion
    :param fctsig: function used to overwrite the fct signature
        in case this one is using `*args, **kwargs`

    Only the ONNX graph and the metadata are pickled
    (see `__getstate__`), the runtime is rebuilt when unpickling.

    .. versionadded:: 0.6
    """

    def __init__(self, fct, op_version=None, runtime=None, signature=None,
                 version=None, fctsig=None):
        if version is not None and not isinstance(version, FctVersion):
            raise TypeError(  # pragma: no cover
                "version must be of Type 'FctVersion' not %s - %s"
                "." % (type(version), version))
        self.fctsig = fctsig
        if op_version is None:
            from .. import __max_supported_opset__
            op_version = __max_supported_opset__
        if hasattr(fct, 'SerializeToString'):
            # fct is already an ONNX graph, there is nothing to convert
            self.fct_ = None
            self.onnx_ = fct
        else:
            self.fct_ = fct
            if not inspect.isfunction(fct):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for fct=%r, it must be a "
                    "function." % type(fct))
            self.onnx_ = None
        self.onnx_ = self._to_onnx(
            op_version=op_version, signature=signature,
            version=version)
        self.runtime_ = self._build_runtime(
            op_version=op_version, runtime=runtime,
            signature=signature, version=version)
        ann = self._parse_annotation(signature=signature, version=version)
        inputs, outputs, kwargs, n_optional, n_variables = ann
        n_opt = 0 if signature is None else signature.n_optional
        args, kwargs2 = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        self.meta_ = dict(op_version=op_version, runtime=runtime,
                          signature=signature, version=version,
                          inputs=inputs, outputs=outputs,
                          kwargs=kwargs, n_optional=n_optional,
                          n_variables=n_variables,
                          args=args, kwargs2=kwargs2,
                          # fct_ is None when fct is an ONNX graph
                          annotations=getattr(
                              self.fct_, '__annotations__', {}))

    def __getstate__(self):
        """
        Serializes everything but function `fct_`.
        Function `fct_` is used to build the onnx graph
        and is not needed anymore.
        """
        return dict(onnx_=self.onnx_, meta_=self.meta_)

    def __setstate__(self, state):
        """
        Restores serialized data and rebuilds the runtime.
        `fct_` and `fctsig` are not serialized, they are reset to None
        so that every method accessing them keeps working.
        """
        self.fct_ = None
        self.fctsig = None
        for k, v in state.items():
            setattr(self, k, v)
        self.runtime_ = self._build_runtime(
            op_version=self.meta_['op_version'],
            runtime=self.meta_['runtime'],
            signature=self.meta_['signature'],
            version=self.meta_['version'])

    def __repr__(self):
        "usual"
        # getattr: attributes may be missing on a partially built instance
        fct = getattr(self, 'fct_', None)
        if fct is not None:
            return "%s(%s)" % (self.__class__.__name__, repr(fct))
        if getattr(self, 'onnx_', None) is not None:
            return "%s(%s)" % (self.__class__.__name__, "... ONNX ... ")
        raise NotImplementedError(  # pragma: no cover
            "fct_ and onnx_ are empty.")

    def _to_onnx_shape(self, shape):
        """
        Converts an annotated shape into the shape given to
        @see cl Variable: `Any` or `Ellipsis` becomes None, a tuple
        becomes a list where `Any` or `Ellipsis` dimensions become None.

        :raises RuntimeError: for any other type of shape
        """
        if shape is Any or shape is Ellipsis:
            shape = None
        elif isinstance(shape, tuple):
            shape = [None if s is Any or s is Ellipsis else s
                     for s in shape]
        else:
            raise RuntimeError(  # pragma: no cover
                "Unexpected annotated shape %r." % shape)
        return shape

    def _parse_annotation(self, signature, version):
        """
        Returns the annotations for function `fct_`.

        :param signature: needed if the annotation is missing,
            then version might be needed to specify which type
            to use if the signature allows many
        :param version: version inside the many signatures possible
        :return: *tuple(inputs, outputs, kwargs, n_optional, n_variables)*,
            *inputs* and *outputs* are lists of @see cl Variable,
            *kwargs* is the dictionary of additional parameters
        """
        n_opt = 0 if signature is None else signature.n_optional
        if hasattr(self, 'meta_'):
            args, kwargs = self.meta_['args'], self.meta_['kwargs2']
        else:
            args, kwargs = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        if version is not None:
            nv = len(version) - len(args) - n_opt
            if (signature is not None and not
                    signature.n_variables and nv > len(kwargs)):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch (%d - %d - %d ? %d) between version=%r and kwargs=%r for "
                    "function %r, optional argument is %d, "
                    "signature=%r." % (
                        len(version), len(args), n_opt, len(kwargs),
                        version, kwargs, self.fct_,
                        signature.n_variables, signature))
            vvers = {} if version.kwargs is None else version.kwargs
            # the version overwrites the default values of the first kwargs
            up = dict(zip(kwargs, vvers))
            kwargs = kwargs.copy()
            kwargs.update(up)

        for k, v in kwargs.items():
            if isinstance(v, (type, numpy.dtype)):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected value for argument %r: %r from %r." % (
                        k, v, kwargs))

        if signature is not None:
            inputs, kwargs, outputs, n_optional, n_variables = (
                signature.get_inputs_outputs(args, kwargs, version))
            inputs = [Variable(i[0], i[1]) for i in inputs]
            outputs = [Variable(i[0], i[1]) for i in outputs]
            return inputs, outputs, kwargs, n_optional, n_variables

        def _possible_names():
            yield 'y'
            yield 'z'  # pragma: no cover
            yield 'o'  # pragma: no cover
            for i in range(0, 10000):  # pragma: no cover
                yield 'o%d' % i

        if hasattr(self, 'meta_'):
            annotations = self.meta_['annotations']
        else:
            annotations = self.fct_.__annotations__
        inputs = []
        outputs = []
        for a in args:
            if a == "op_version":
                continue
            if a not in annotations:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find annotation for argument %r. "
                    "You should annotate the arguments and the results "
                    "or specify a signature." % a)
            ann = annotations[a]
            shape, dtype = ann.__args__
            shape = self._to_onnx_shape(shape)
            inputs.append(Variable(a, dtype, shape=shape))

        ret = annotations['return']
        names_in = set(inp.name for inp in inputs)

        # A single annotated result is handled as a tuple of one result,
        # output names must differ from input names and from each other.
        rets = ret if isinstance(ret, tuple) else (ret, )
        names_out = set()
        for shape_dtype in rets:
            shape, dtype = shape_dtype.__args__
            shape = self._to_onnx_shape(shape)
            name_out = None
            for name in _possible_names():
                if name not in names_in and name not in names_out:
                    name_out = name
                    break
            outputs.append(Variable(name_out, dtype, shape=shape))
            names_out.add(name_out)
        # signature is None here, the other case returned above
        return inputs, outputs, kwargs, 0, False

    def _find_hidden_algebras(self, onx_var, onx_algebra):
        """
        Subgraph are using inputs not linked to the others nodes.
        This function retrieves them as they are stored in
        attributes `alg_hidden_var_`. The function looks into every
        node linked to the inputs and their predecessors.

        :param onx_var: @see cl OnnxVar
        :param onx_algebra: OnnxOperator (unused)
        :return: tuple(dictionary `{id(obj): (var, obj)}`,
            all instance of @see cl OnnxVarGraph)
        """
        keep_hidden = {}
        var_graphs = []
        visited = set()
        stack = [onx_var]
        while stack:
            var = stack.pop()
            # a node shared by several consumers is inspected once,
            # otherwise the walk is exponential on diamond-shaped graphs
            if id(var) in visited:
                continue
            visited.add(id(var))
            hidden = getattr(var, 'alg_hidden_var_', None)
            if hidden is not None:
                if any(len(x) > 0
                       for x in var.alg_hidden_var_inputs.values()):
                    keep_hidden.update(hidden)
                var_graphs.append(var)
            if hasattr(var, 'inputs'):
                stack.extend(var.inputs)
        return keep_hidden, var_graphs

    def _to_onnx(self, op_version=None, signature=None, version=None):
        """
        Returns the onnx graph produced by function `fct_`.
        The graph is built and optimized on the first call only,
        it is then stored in `onnx_`.

        :raises RuntimeError: if no graph can be obtained
        """
        if self.onnx_ is None and self.fct_ is not None:
            from .onnx_variable import OnnxVar
            logger.debug('OnnxNumpyCompiler._to_onnx(op_version=%r, '
                         'signature=%r, version=%r)',
                         op_version, signature, version)
            inputs, outputs, kwargs, n_optional, n_variables = (  # pylint: disable=W0612
                self._parse_annotation(
                    signature=signature, version=version))
            if ((signature is None or not signature.n_variables) and
                    isinstance(version, tuple) and
                    len(inputs) > len(version)):
                raise NotImplementedError(  # pragma: no cover
                    "Mismatch between additional parameters %r "
                    "(n_optional=%r) and version %r for function %r from %r."
                    "" % (kwargs, n_optional, version, self.fct_,
                          getattr(self.fct_, '__module__', None)))
            names_in = [oi.name for oi in inputs]
            names_out = [oi.name for oi in outputs]
            names_var = [OnnxVar(n, dtype=dt.dtype)
                         for n, dt in zip(names_in, inputs)]

            logger.debug('OnnxNumpyCompiler._to_onnx:names_in=%r', names_in)
            logger.debug('OnnxNumpyCompiler._to_onnx:names_out=%r', names_out)

            if 'op_version' in self.fct_.__code__.co_varnames:
                # the function builds the operators itself
                onx_var = None
                onx_algebra = self.fct_(
                    *names_in, op_version=op_version, **kwargs)
            else:
                onx_var = self.fct_(*names_var, **kwargs)
                if not hasattr(onx_var, 'to_algebra'):
                    raise TypeError(  # pragma: no cover
                        "The function %r to convert must return an instance of "
                        "OnnxVar but returns type %r." % (self.fct_, type(onx_var)))
                onx_algebra = onx_var.to_algebra(op_version=op_version)

            logger.debug('OnnxNumpyCompiler._to_onnx:onx_var=%r',
                         type(onx_var))
            logger.debug('OnnxNumpyCompiler._to_onnx:onx_algebra=%r',
                         type(onx_algebra))

            if not isinstance(onx_algebra, (OnnxOperator, OnnxOperatorTuple)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for onx_algebra %r "
                    "(It should be OnnxOperator or OnnxOperatorTuple), "
                    "function is %r." % (type(onx_algebra), self.fct_))
            hidden_algebras, var_graphs = self._find_hidden_algebras(
                onx_var, onx_algebra)
            if len(hidden_algebras) > 0:
                logger.debug(  # pragma: no cover
                    'OnnxNumpyCompiler._to_onnx:len(hidden_algebras)=%r',
                    len(hidden_algebras))
                raise NotImplementedError(  # pragma: no cover
                    "Subgraphs only support constants (operator If, Loop, "
                    "Scan). hidden_algebras=%r var_graphs=%r" % (
                        hidden_algebras, var_graphs))

            if isinstance(onx_algebra, str):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected str type %r." % onx_algebra)
            if isinstance(onx_algebra, tuple):
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented when the function returns multiple results.")
            if hasattr(onx_algebra, 'to_onnx'):
                onx_algebra.output_names = [Variable(n) for n in names_out]
                onx = onx_algebra.to_onnx(
                    inputs=inputs, target_opset=op_version, outputs=outputs)
                # optimisation
                onx_optimized = onnx_optimisations(onx)
                self.onnx_ = onx_optimized

        if self.onnx_ is None:
            raise RuntimeError(  # pragma: no cover
                "Unable to get the ONNX graph (class %r, fct_=%r)" % (
                    type(self), self.fct_))
        return self.onnx_

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.

        :return: ONNX graph
        :raises NotImplementedError: if *kwargs* is not empty
        """
        if len(kwargs) > 0:
            raise NotImplementedError(  # pragma: no cover
                "kwargs is not empty, this case is not implemented. "
                "kwargs=%r." % kwargs)
        if hasattr(self, 'onnx_'):
            return self.onnx_
        raise NotImplementedError(  # pragma: no cover
            "Attribute 'onnx_' is missing.")

    def _build_runtime(self, op_version=None, runtime=None,
                       signature=None, version=None):
        """
        Creates the runtime for the :epkg:`ONNX` graph.

        :param op_version: :epkg:`ONNX` opset to use, None
            for the latest one
        :param runtime: runtime to choose to execute the onnx graph,
            `python`, `onnxruntime`, `onnxruntime1`
        :param signature: used when the function is not annotated
        :param version: version inside the many signatures possible,
            see @see cl FctVersion
        :return: an instance of @see cl OnnxNumpyFunctionInferenceSession
            if *runtime* is `onnxruntime` or `onnxruntime-cuda`,
            @see cl OnnxNumpyFunctionOnnxInference otherwise,
            it is also stored in `rt_fct_`
        """
        onx = self._to_onnx(op_version=op_version, signature=signature,
                            version=version)
        inputs, outputs, _, n_optional, n_variables = self._parse_annotation(
            signature=signature, version=version)
        if runtime not in ('onnxruntime', 'onnxruntime-cuda'):
            from ..onnxrt import OnnxInference
            rt = OnnxInference(onx, runtime=runtime)
            self.rt_fct_ = OnnxNumpyFunctionOnnxInference(
                self, rt, inputs=inputs, outputs=outputs,
                n_optional=n_optional, n_variables=n_variables)
        else:
            from ..tools.ort_wrapper import InferenceSession
            rt = InferenceSession(onx.SerializeToString(), runtime=runtime)
            self.rt_fct_ = OnnxNumpyFunctionInferenceSession(
                self, rt, inputs=inputs, outputs=outputs,
                n_optional=n_optional, n_variables=n_variables)
        return self.rt_fct_

    def __call__(self, *args, **kwargs):
        """
        Executes the function and returns the results.

        :param args: arguments
        :return: results, a single value if there is only one output,
            a tuple otherwise
        """
        res = self.rt_fct_(*args, **kwargs)
        if len(res) == 1:
            return res[0]
        return res