Coverage for mlprodict/onnxrt/ops_shape/_element_wise.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

40 statements  

1""" 

2@file 

3@brief Computes shape inference for element wise operators. 

4""" 

5from .shape_excs import ShapeInferenceException 

6from .shape_result import ShapeResult, OnnxKind 

7 

8 

def _element_wise(known_shapes, node):
    """
    Infers shape for an element wise operator by broadcasting
    the shapes of all its inputs together.
    The function updates *known_shapes* and returns the value
    returned by ``known_shapes.update``.

    Most operators handled here (Add, Sub, Mul, ...) take exactly two
    inputs, but Max and Min are variadic in ONNX (one or more inputs).
    Every input therefore takes part in the broadcast, not only the
    first two; for two inputs the result is the same as before.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not
    :raises ShapeInferenceException: if the node has no input or
        if one of its inputs is not a tensor
    """
    output_name = node.output[0]
    if len(node.input) == 0:
        raise ShapeInferenceException(  # pragma: no cover
            "Node producing %r has no input." % output_name)

    # Fetch every input first, then validate, as the original code did.
    shapes = [known_shapes[name] for name in node.input]
    for shape in shapes:
        if shape.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                "Result %r must be a tensor." % shape)

    if len(shapes) == 1:
        # A single input (e.g. Max(x)) keeps its own shape: broadcasting a
        # shape with itself yields that shape and builds a result named
        # after the output through the usual code path.
        result = ShapeResult.broadcast(
            shapes[0], shapes[0], name=output_name)
    else:
        # Broadcasting is associative, so fold it pairwise over all inputs.
        result = shapes[0]
        for shape in shapes[1:]:
            result = ShapeResult.broadcast(result, shape, name=output_name)
    return known_shapes.update(output_name, result)

28 

29 

def shape_add(known_shapes, node):
    """
    Shape inference for the ONNX operator Add.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

33 

34 

def shape_and(known_shapes, node):
    """
    Shape inference for the ONNX operator And.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

38 

39 

def shape_div(known_shapes, node):
    """
    Shape inference for the ONNX operator Div.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

43 

44 

def shape_equal(known_shapes, node):
    """
    Shape inference for the ONNX operator Equal.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

48 

49 

def shape_greater(known_shapes, node):
    """
    Shape inference for the ONNX operator Greater.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

53 

54 

def shape_greaterorequal(known_shapes, node):
    """
    Shape inference for the ONNX operator GreaterOrEqual.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

58 

59 

def shape_less(known_shapes, node):
    """
    Shape inference for the ONNX operator Less.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

63 

64 

def shape_lessorequal(known_shapes, node):
    """
    Shape inference for the ONNX operator LessOrEqual.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

68 

69 

def shape_max(known_shapes, node):
    """
    Shape inference for the ONNX operator Max.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

73 

74 

def shape_min(known_shapes, node):
    """
    Shape inference for the ONNX operator Min.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

78 

79 

def shape_mod(known_shapes, node):
    """
    Shape inference for the ONNX operator Mod.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

83 

84 

def shape_mul(known_shapes, node):
    """
    Shape inference for the ONNX operator Mul.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

88 

89 

def shape_or(known_shapes, node):
    """
    Shape inference for the ONNX operator Or.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

93 

94 

def shape_pow(known_shapes, node):
    """
    Shape inference for the ONNX operator Pow.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

98 

99 

def shape_sub(known_shapes, node):
    """
    Shape inference for the ONNX operator Sub.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated

103 

104 

def shape_xor(known_shapes, node):
    """
    Shape inference for the ONNX operator Xor.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not, as returned by the element wise helper
    """
    updated = _element_wise(known_shapes, node)
    return updated