Coverage for mlprodict/npy/onnx_numpy_annotation.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

174 statements  

1""" 

2@file 

3@brief :epkg:`numpy` annotations. 

4 

5.. versionadded:: 0.6 

6""" 

7import inspect 

8from collections import OrderedDict 

9from typing import TypeVar, Generic 

10import numpy 

11from .onnx_version import FctVersion 

12 

# Scalar aliases of numpy, the builtin type is used when
# the installed numpy does not define the alias.
numpy_bool = getattr(numpy, "bool_", bool)
numpy_str = getattr(numpy, "str_", str)

# Type variables used to parametrize the annotation NDArray[Shape, DType].
Shape = TypeVar("Shape")
DType = TypeVar("DType")

# Element types the shortcut 'all' expands to.
all_dtypes = (
    numpy.float32, numpy.float64,
    numpy.int32, numpy.int64,
    numpy.uint32, numpy.uint64,
)

30 

31 

def get_args_kwargs(fct, n_optional):
    """
    Extracts arguments and optional parameters of a function.

    :param fct: function
    :param n_optional: number of arguments to consider as
        optional arguments and not parameters, the first
        *n_optional* parameters with a default value are
        returned among the arguments
    :return: arguments, OrderedDict

    Parameter *op_version* is never returned among the optional
    parameters. If the first argument is *self*, it is removed
    from the arguments and key ``'op_'`` is added to the optional
    parameters with value None.
    """
    empty = inspect.Parameter.empty
    args = []
    kwargs = OrderedDict()
    for name, p in inspect.signature(fct).parameters.items():
        # Identity test: a default value such as a numpy array
        # cannot be compared with ``==`` inside a condition.
        if p.default is empty:
            args.append(name)
        elif n_optional > 0:
            # optional input, it is not a parameter
            args.append(name)
            n_optional -= 1
        elif name != 'op_version':
            kwargs[name] = p.default

    # args may be empty if every parameter has a default value
    if args and args[0] == 'self':
        args = args[1:]
        kwargs['op_'] = None
    return args, kwargs

69 

70 

class NDArray(numpy.ndarray, Generic[Shape, DType]):
    """
    Used to annotate ONNX numpy functions.
    Expression ``NDArray[...]`` does not create an array, it returns
    an instance of :class:`NDArray.ShapeType` storing what was
    given between the brackets.

    .. versionadded:: 0.6
    """
    class ShapeType:
        "Stores shape information."

        def __init__(self, params):
            self.__args__ = params

    def __class_getitem__(cls, params):  # pylint: disable=W0221,W0237
        "Wraps the parameters into a :class:`ShapeType`."
        # a single parameter is stored as a tuple of one element
        as_tuple = params if isinstance(params, tuple) else (params, )
        return NDArray.ShapeType(as_tuple)

88 

89 

class _NDArrayAlias:
    """
    Ancestor to custom signature.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    *dtypes*, *dtypes_out* by default are a tuple of tuple:

    * first dimension: type of every input
    * second dimension: list of types for one input

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None,
                 nvars=False):
        "constructor"
        if dtypes is None:
            raise ValueError("dtypes cannot be None.")  # pragma: no cover
        if isinstance(dtypes, tuple) and len(dtypes) == 0:
            raise TypeError("dtypes must not be empty.")  # pragma: no cover
        if isinstance(dtypes, tuple) and not isinstance(dtypes[0], tuple):
            # one type per input: every type becomes a tuple of one type,
            # strings are shortcuts and are expanded by _process_type
            dtypes = tuple(t if isinstance(t, str) else (t,) for t in dtypes)
        if isinstance(dtypes, str) and '_' in dtypes:
            # 'in_out': input and output types in a single string
            dtypes, dtypes_out = dtypes.split('_')
        if not isinstance(dtypes, (tuple, list)):
            dtypes = (dtypes, )

        # named types ('T:all'), name -> (shortcut, input index)
        self.mapped_types = {}
        self.dtypes = _NDArrayAlias._process_type(
            dtypes, self.mapped_types, 0)
        if dtypes_out is None:
            # same type as the first input
            self.dtypes_out = (self.dtypes[0], )
        elif isinstance(dtypes_out, int):
            # same type as input number *dtypes_out*
            self.dtypes_out = (self.dtypes[dtypes_out], )
        else:
            if not isinstance(dtypes_out, (tuple, list)):
                dtypes_out = (dtypes_out, )
            self.dtypes_out = _NDArrayAlias._process_type(
                dtypes_out, self.mapped_types, 0)
        self.n_optional = 0 if n_optional is None else n_optional
        self.n_variables = nvars

        # self.dtypes must be a non empty tuple of non empty flat tuples
        if not isinstance(self.dtypes, tuple):
            raise TypeError(  # pragma: no cover
                "self.dtypes must be a tuple not {}.".format(self.dtypes))
        if (len(self.dtypes) == 0 or
                not isinstance(self.dtypes[0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes: {}.".format(self.dtypes))
        if (len(self.dtypes[0]) == 0 or
                isinstance(self.dtypes[0][0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes: {}.".format(self.dtypes))

        # same verifications for self.dtypes_out
        if not isinstance(self.dtypes_out, tuple):
            raise TypeError(  # pragma: no cover
                "self.dtypes_out must be a tuple not {}.".format(self.dtypes_out))
        if (len(self.dtypes_out) == 0 or
                not isinstance(self.dtypes_out[0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes_out={}, "
                "self.dtypes={}.".format(self.dtypes_out, self.dtypes))
        if (len(self.dtypes_out[0]) == 0 or
                isinstance(self.dtypes_out[0][0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes_out: {}.".format(self.dtypes_out))

        if self.n_variables and self.n_optional > 0:
            raise RuntimeError(  # pragma: no cover
                "n_variables and n_optional cannot be positive at "
                "the same time.")

    @staticmethod
    def _process_type(dtypes, mapped_types, index):
        """
        Nicknames such as `floats`, `int`, `ints`, `all`
        can be used to describe multiple inputs for
        a signature. This function interprets that.

        :param dtypes: a shortcut (string), a type or a sequence of them,
            a shortcut can be named with syntax ``name:shortcut``
        :param mapped_types: dictionary of named types,
            it is updated by this function
        :param index: position of the input *dtypes* describes
        :return: tuple of types, or a tuple of tuples if
            *dtypes* is a sequence

        A name can be declared more than once as long as it is always
        associated to the same shortcut, the index of the first
        declaration is kept.

        .. runpython::
            :showcode:

            from mlprodict.npy.onnx_numpy_annotation import _NDArrayAlias
            for name in ['all', 'int', 'ints', 'floats', 'T']:
                print(name, _NDArrayAlias._process_type(name, {'T': 0}, 0))
        """
        if isinstance(dtypes, str):
            if ":" in dtypes:
                name, dtypes = dtypes.split(':')
                if name in mapped_types:
                    # the dictionary stores (shortcut, index),
                    # only the shortcut is compared
                    known = mapped_types[name]
                    if isinstance(known, tuple):
                        known = known[0]
                    if dtypes != known:
                        raise RuntimeError(  # pragma: no cover
                            "Type name mismatch for '%s:%s' in %r." % (
                                name, dtypes, list(sorted(mapped_types))))
                else:
                    mapped_types[name] = (dtypes, index)
            if dtypes == "all":
                dtypes = all_dtypes
            elif dtypes in ("int", "int64"):
                dtypes = (numpy.int64, )
            elif dtypes == "bool":
                dtypes = (numpy_bool, )
            elif dtypes == "floats":
                dtypes = (numpy.float32, numpy.float64)
            elif dtypes == "ints":
                dtypes = (numpy.int32, numpy.int64)
            elif dtypes == "float32":
                dtypes = (numpy.float32, )
            elif dtypes == "float64":
                dtypes = (numpy.float64, )
            elif dtypes not in mapped_types:
                raise ValueError(  # pragma: no cover
                    "Unexpected shortcut for dtype %r." % dtypes)
            elif not isinstance(dtypes, tuple):
                # name of a mapped type, resolved by _get_output_types
                dtypes = (dtypes, )
            return dtypes

        if isinstance(dtypes, (tuple, list)):
            insig = [_NDArrayAlias._process_type(dt, mapped_types, index + d)
                     for d, dt in enumerate(dtypes)]
            return tuple(insig)

        if dtypes in all_dtypes:
            return dtypes

        raise NotImplementedError(  # pragma: no cover
            "Unexpected input dtype %r." % dtypes)

    def __repr__(self):
        "usual"
        return "%s(%r, %r, %r)" % (
            self.__class__.__name__, self.dtypes, self.dtypes_out,
            self.n_optional)

    def _get_output_types(self, key):
        """
        Tries to infer output types.

        :param key: input types
        :return: tuple of output types, one per output
        """
        res = []
        for i, o in enumerate(self.dtypes_out):
            if not isinstance(o, tuple):
                raise TypeError(  # pragma: no cover
                    "All outputs must be tuple, output %d is %r."
                    "" % (i, o))
            if (len(o) == 1 and (o[0] in all_dtypes or
                                 o[0] in (bool, numpy_bool, str, numpy_str))):
                # the output type is fixed
                res.append(o[0])
            elif len(o) == 1 and o[0] in self.mapped_types:
                # named type, same type as the input it was declared for
                info = self.mapped_types[o[0]]
                res.append(key[info[1]])
            elif key[0] in o:
                # same type as the first input if it is allowed
                res.append(key[0])
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to guess output type for output %d, "
                    "input types are %r, expected output is %r."
                    "" % (i, key, o))
        return tuple(res)

    def get_inputs_outputs(self, args, kwargs, version):
        """
        Returns the list of inputs, outputs.

        :param args: list of arguments
        :param kwargs: list of optional arguments
        :param version: required version
        :return: *tuple(inputs, kwargs, outputs, optional, n_variables)*,
            inputs and outputs are lists of tuples *(name, type)*,
            kwargs are the arguments as they were received,
            *optional* is the number of optional arguments,
            *n_variables* is the value of parameter *nvars*
            given to the constructor
        """
        if not isinstance(version, FctVersion):
            raise TypeError("Version must be of type 'FctVersion' not "
                            "%s, version=%s." % (type(version), version))
        if args == ['args', 'kwargs']:
            raise RuntimeError(  # pragma: no cover
                "Issue with signature %r." % args)
        for k, v in kwargs.items():
            if isinstance(v, type):
                raise RuntimeError(  # pragma: no cover
                    "Default value for argument %r must not be of type %r"
                    "." % (k, v))
        if (not self.n_variables and
                len(args) > len(self.dtypes)):
            raise RuntimeError(
                "Unexpected number of inputs version=%s.\n"
                "Given: args=%s dtypes=%s." % (
                    version, args, self.dtypes))

        def _possible_names():
            yield 'y'
            yield 'z'  # pragma: no cover
            yield 'o'  # pragma: no cover
            for i in range(0, 10000):  # pragma: no cover
                yield 'o%d' % i

        # only used to verify the number of inputs and parameters
        new_kwargs = OrderedDict(zip(kwargs, version.kwargs or tuple()))
        if self.n_variables:
            # undefined number of inputs
            optional = 0
        else:
            optional = len(self.dtypes) - len(version.args)
            if optional > self.n_optional:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of optional parameters %d, at most "
                    "%d are expected, version=%s, args=%s, dtypes=%s." % (
                        optional, self.n_optional, version, args, self.dtypes))
            optional = self.n_optional - optional

        onnx_types = list(version.args)
        inputs = list(zip(args[:len(version.args)], onnx_types))
        if self.n_variables and len(inputs) < len(version.args):
            # Complete the list of inputs
            last_name = inputs[-1][0]
            while len(inputs) < len(onnx_types):
                inputs.append(('%s%d' % (last_name, len(inputs)),
                               onnx_types[len(inputs)]))

        key_out = self._get_output_types(version.args)
        onnx_types_out = key_out

        # every output receives the first name which is not used yet
        names_out = []
        names_in = set(inp[0] for inp in inputs)
        for _ in key_out:
            for name in _possible_names():
                if name not in names_in:
                    break
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find a name for an output.")
            names_out.append(name)
            names_in.add(name)

        outputs = list(zip(names_out, onnx_types_out))
        if optional < 0:
            raise RuntimeError(  # pragma: no cover
                "optional cannot be negative %r (self.n_optional=%r, "
                "len(self.dtypes)=%r, len(inputs)=%r) "
                "names_in=%r, names_out=%r." % (
                    optional, self.n_optional, len(self.dtypes),
                    len(inputs), names_in, names_out))

        if (not self.n_variables and
                len(inputs) + len(new_kwargs) > len(version)):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of inputs and arguments for version=%s.\n"
                "Given: args=%s kwargs=%s.\n"
                "Returned: inputs=%s new_kwargs=%s.\n" % (
                    version, args, kwargs, inputs, new_kwargs))
        if not self.n_variables and len(inputs) > len(self.dtypes):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of inputs for version=%s.\n"
                "Given: args=%s.\n"
                "Expected: dtypes=%s\n"
                "Returned: inputs=%s.\n" % (
                    version, args, self.dtypes, inputs))

        return inputs, kwargs, outputs, optional, self.n_variables

    def shape_calculator(self, dims):
        """
        Returns expected dimensions given the input dimensions.

        :param dims: input dimensions
        :return: None if *dims* is empty, otherwise a list of the same
            length, the first dimension is kept, the others are None
        """
        if len(dims) == 0:
            return None
        res = [dims[0]]
        for _ in dims[1:]:
            res.append(None)
        return res

362 

363 

class NDArrayType(_NDArrayAlias):
    """
    Shortcut to simplify signature description,
    every parameter is given to the parent class.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        super().__init__(
            dtypes=dtypes, dtypes_out=dtypes_out,
            n_optional=n_optional, nvars=nvars)

380 

381 

class NDArrayTypeSameShape(NDArrayType):
    """
    Shortcut to simplify signature description,
    every parameter is given to the parent class.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        super().__init__(
            dtypes=dtypes, dtypes_out=dtypes_out,
            n_optional=n_optional, nvars=nvars)

398 

399 

class NDArraySameType(NDArrayType):
    """
    Shortcut to simplify signature description when the signature
    is described by a single type.

    :param dtypes: input dtypes, a single type, tuples and strings
        including ``'_'`` are not allowed

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        if dtypes is None:
            raise ValueError("dtypes cannot be None.")  # pragma: no cover
        two_types = isinstance(dtypes, str) and "_" in dtypes
        if two_types:
            raise ValueError(  # pragma: no cover
                "dtypes cannot include '_' meaning two different types.")
        if isinstance(dtypes, tuple):
            raise ValueError(  # pragma: no cover
                "dtypes must be a single type.")
        # the parent class receives a tuple of one type
        super().__init__(dtypes=(dtypes, ))

    def __repr__(self):
        "usual"
        return "{}({!r})".format(self.__class__.__name__, self.dtypes)

424 

425 

class NDArraySameTypeSameShape(NDArraySameType):
    """
    Shortcut to simplify signature description,
    the parameter is given to the parent class.

    :param dtypes: input dtypes

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        super().__init__(dtypes=dtypes)