Coverage for mlprodict/onnxrt/ops_shape/__init__.py: 95%


38 statements  

1""" 

2@file 

3@brief Shortcut to *ops_shape*. 

4""" 

5import textwrap 

6from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611 

7from ...onnx_tools.onnx2py_helper import get_onnx_schema 

8from ._element_unary import ( 

9 shape_abs, shape_acos, shape_acosh, 

10 shape_asin, shape_asinh, shape_atan, shape_atanh, 

11 shape_castlike, shape_ceil, shape_celu, 

12 shape_clip, shape_cos, shape_cosh, 

13 shape_elu, shape_erf, shape_exp, shape_floor, 

14 shape_hardmax, shape_hardsigmoid, 

15 shape_identity, shape_isinf, shape_isnan, 

16 shape_leakyrelu, shape_log, shape_logsoftmax, 

17 shape_neg, shape_not, shape_reciprocal, shape_relu, shape_round, 

18 shape_selu, 

19 shape_sigmoid, shape_sign, shape_sin, shape_sinh, shape_softmax, 

20 shape_sqrt, shape_tan, shape_tanh, shape_trilu) 

21from ._element_wise import ( 

22 shape_add, shape_and, 

23 shape_div, 

24 shape_equal, 

25 shape_greater, shape_greaterorequal, 

26 shape_less, shape_lessorequal, 

27 shape_max, shape_min, shape_mod, shape_mul, 

28 shape_or, 

29 shape_pow, 

30 shape_sub, 

31 shape_xor) 

32from ._op_shape_op import shape_det 

33 

34 

35_shape_functions = { 

36 k: v for k, v in globals().items() if k.startswith("shape_") 

37} 

38 

39 

40count = [0] 

41 

42 

def shape_dispatch(cache, known_shape, node, rt_class=None):
    """
    Calls the corresponding shape function for every node.

    :param cache: cache of shape functions already resolved,
        indexed by ``(domain, op_type)``
    :param known_shape: known shapes for all results
    :param node: onnx node
    :param rt_class: a node may be a predefined function in onnx;
        if no specific shape function is available, the predefined
        onnx definition is used and run through this runtime class
    :return: True if *known_shape* was updated, False otherwise
    """
    key = node.domain, node.op_type
    fct_shape = None
    if key in cache:
        fct_shape = cache[key]
    else:
        op_type = "shape_" + node.op_type.lower()
        if op_type in _shape_functions:
            fct_shape = _shape_functions[op_type]
            cache[key] = fct_shape

    if fct_shape is None and rt_class is not None:
        # check this operator is a predefined function in ONNX.
        try:
            onnx_schema = get_onnx_schema(node.op_type, node.domain)
        except SchemaError:
            onnx_schema = None
        if onnx_schema is not None and onnx_schema.has_function:
            sess = rt_class(onnx_schema.function_body)
            if len(node.input) != len(sess.input_names):
                raise RuntimeError(  # pragma: no cover
                    "node and function must have the same number of inputs, "
                    "len(%r) != len(%r)." % (
                        node.input, sess.input_names))
            if len(node.output) != len(sess.output_names):
                raise RuntimeError(  # pragma: no cover
                    "node and function must have the same number of outputs, "
                    "len(%r) != len(%r)." % (
                        node.output, sess.output_names))

            def _shape_function(known_shape, node):
                # Runs the function body on the node inputs and
                # propagates the computed shapes to the node outputs.
                inputs = {iname: known_shape[name] for name, iname in
                          zip(node.input, sess.input_names)}
                outputs = sess.run(inputs)
                res = False
                for name, oname in zip(node.output, sess.output_names):
                    r = known_shape.update(name, outputs[oname])
                    res = res or r
                return res

            fct_shape = _shape_function
            cache[key] = fct_shape

    if fct_shape is not None:
        return fct_shape(known_shape, node)

    raise RuntimeError(  # pragma: no cover
        "Unable to find a corresponding function for operator type %r "
        "domain=%r, looking for %r among\n%s" % (
            node.op_type, node.domain, "shape_" + node.op_type.lower(),
            "\n".join(textwrap.wrap(
                " ".join(_ for _ in sorted(_shape_functions))))))