Coverage for mlprodict/onnxrt/ops_shape/_element_unary.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

100 statements  

1""" 

2@file 

3@brief Computes shape inference for element wise operators with one input. 

4""" 

5import numpy 

6from .shape_excs import ShapeInferenceException 

7from .shape_result import OnnxKind 

8 

9 

def _element_unary(known_shapes, node, dtype=None):
    """
    Infers shape for an element wise operator with a single input.
    The output result is a copy of the first input result (same shape),
    optionally with a different element type. The function updates
    *known_shapes* with that output result and returns the value
    returned by `known_shapes.update`.

    :param known_shapes: known shapes
    :param node: Onnx node
    :param dtype: None to keep the same type as input,
        not None to change it
    :return: updated or not
    :raises ShapeInferenceException: if the first input is not a tensor
    """
    x = known_shapes[node.input[0]]
    if x.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            "Result %r must be a tensor." % x)
    cp = x.copy()
    if dtype is not None:
        # Some operators (IsNaN, IsInf) produce a different element type.
        cp.dtype = dtype
    return known_shapes.update(node.output[0], cp)

30 

31 

def shape_abs(known_shapes, node):
    """
    Shape inference for operator Abs: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

35 

36 

def shape_acos(known_shapes, node):
    """
    Shape inference for operator Acos: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

40 

41 

def shape_acosh(known_shapes, node):
    """
    Shape inference for operator Acosh: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

45 

46 

def shape_asin(known_shapes, node):
    """
    Shape inference for operator Asin: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

50 

51 

def shape_asinh(known_shapes, node):
    """
    Shape inference for operator Asinh: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

55 

56 

def shape_atan(known_shapes, node):
    """
    Shape inference for operator Atan: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

60 

61 

def shape_atanh(known_shapes, node):
    """
    Shape inference for operator Atanh: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

65 

66 

def shape_castlike(known_shapes, node):
    """
    Shape inference for operator CastLike: the output keeps the shape
    of the first input and takes the element type of the second input.

    :raises ShapeInferenceException: if either input is not a tensor
    """
    # Inputs are fetched and validated in order: first, then second.
    checked = []
    for position in (0, 1):
        res = known_shapes[node.input[position]]
        if res.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                "Result %r must be a tensor." % res)
        checked.append(res)
    source, like = checked
    output = source.copy()
    output.dtype = like.dtype
    return known_shapes.update(node.output[0], output)

80 

81 

def shape_ceil(known_shapes, node):
    """
    Shape inference for operator Ceil: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

85 

86 

def shape_celu(known_shapes, node):
    """
    Shape inference for operator Celu: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

90 

91 

def shape_clip(known_shapes, node):
    """
    Shape inference for operator Clip: the output result is a copy
    of the first input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

95 

96 

def shape_cos(known_shapes, node):
    """
    Shape inference for operator Cos: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

100 

101 

def shape_cosh(known_shapes, node):
    """
    Shape inference for operator Cosh: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

105 

106 

def shape_elu(known_shapes, node):
    """
    Shape inference for operator Elu: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

110 

111 

def shape_erf(known_shapes, node):
    """
    Shape inference for operator Erf: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

115 

116 

def shape_exp(known_shapes, node):
    """
    Shape inference for operator Exp: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

120 

121 

def shape_floor(known_shapes, node):
    """
    Shape inference for operator Floor: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

125 

126 

def shape_hardmax(known_shapes, node):
    """
    Shape inference for operator Hardmax: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

130 

131 

def shape_hardsigmoid(known_shapes, node):
    """
    Shape inference for operator HardSigmoid: the output result is a
    copy of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

135 

136 

def shape_identity(known_shapes, node):
    """
    Shape inference for operator Identity: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

140 

141 

def shape_isnan(known_shapes, node):
    """
    Shape inference for operator IsNan: the output keeps the input
    shape but its element type becomes `numpy.bool_`.
    """
    return _element_unary(known_shapes, node, dtype=numpy.bool_)

145 

146 

def shape_isinf(known_shapes, node):
    """
    Shape inference for operator IsInf: the output keeps the input
    shape but its element type becomes `numpy.bool_`.
    """
    return _element_unary(known_shapes, node, dtype=numpy.bool_)

150 

151 

def shape_leakyrelu(known_shapes, node):
    """
    Shape inference for operator LeakyRelu: the output result is a
    copy of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

155 

156 

def shape_log(known_shapes, node):
    """
    Shape inference for operator Log: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

160 

161 

def shape_logsoftmax(known_shapes, node):
    """
    Shape inference for operator LogSoftmax, which behaves like
    Softmax from a shape point of view: delegates to `shape_softmax`.
    """
    result = shape_softmax(known_shapes, node)
    return result

165 

166 

def shape_neg(known_shapes, node):
    """
    Shape inference for operator Neg: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

170 

171 

def shape_not(known_shapes, node):
    """
    Shape inference for operator Not: the input element type must be
    `numpy.bool_`; the output result is then a copy of the input result.

    :raises ShapeInferenceException: if the input type is not bool
    """
    input_type = known_shapes[node.input[0]].dtype
    if input_type != numpy.bool_:
        message = "Unexpected input type for operator Not %r (must be bool)."
        raise ShapeInferenceException(message % input_type)
    return _element_unary(known_shapes, node)

180 

181 

def shape_reciprocal(known_shapes, node):
    """
    Shape inference for operator Reciprocal: the output result is a
    copy of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

185 

186 

def shape_relu(known_shapes, node):
    """
    Shape inference for operator Relu: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

190 

191 

def shape_round(known_shapes, node):
    """
    Shape inference for operator Round: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

195 

196 

def shape_selu(known_shapes, node):
    """
    Shape inference for operator Selu: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

200 

201 

def shape_sigmoid(known_shapes, node):
    """
    Shape inference for operator Sigmoid: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

205 

206 

def shape_sign(known_shapes, node):
    """
    Infers shape for operator Sign: the output result is a copy
    of the input result (same shape, same element type).
    """
    # The docstring previously said "Sigmoid" (copy-paste slip).
    return _element_unary(known_shapes, node)

210 

211 

def shape_sin(known_shapes, node):
    """
    Shape inference for operator Sin: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

215 

216 

def shape_sinh(known_shapes, node):
    """
    Shape inference for operator Sinh: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

220 

221 

def shape_softmax(known_shapes, node):
    """
    Shape inference for operator Softmax: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

225 

226 

def shape_sqrt(known_shapes, node):
    """
    Shape inference for operator Sqrt: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

230 

231 

def shape_tan(known_shapes, node):
    """
    Shape inference for operator Tan: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

235 

236 

def shape_tanh(known_shapes, node):
    """
    Shape inference for operator Tanh: the output result is a copy
    of the input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)

240 

241 

def shape_trilu(known_shapes, node):
    """
    Shape inference for operator Trilu: the output result is a copy
    of the first input result (same shape, same element type).
    """
    return _element_unary(known_shapes, node, dtype=None)