Coverage for mlprodict/onnxrt/ops_onnxruntime/_op.py: 96%

134 statements  

# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_onnxruntime*.
"""
import numpy
import onnx.defs
from onnx.helper import make_tensor
from onnx.onnx_cpp2py_export.shape_inference import InferenceError  # pylint: disable=E0401,E0611
from ...tools.ort_wrapper import InferenceSession
from ...onnx_tools.onnx2py_helper import guess_proto_dtype
from ...onnx_tools.optim.graph_schema_helper import (
    get_defined_inputs, get_defined_outputs, proto2vars)


_schemas = {
    schema.name: schema for schema in onnx.defs.get_all_schemas_with_history()}
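
# Illustrative note (not part of the original module): _schemas maps every
# operator name to one onnx OpSchema object taken from
# onnx.defs.get_all_schemas_with_history(); for example,
# _schemas['Add'].since_version gives the opset attached to that schema of 'Add'.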

class OpRunOnnxRuntime:
    """
    Unique operator which calls :epkg:`onnxruntime`
    to compute predictions for one operator.
    """

    def __init__(self, onnx_node, desc=None, variables=None,
                 dtype=None, runtime=None, **options):
        """
        :param onnx_node: :epkg:`onnx` node
        :param desc: internal representation
        :param variables: registered variables created by previous operators
        :param dtype: float computation type
        :param options: runtime options
        :param runtime: `onnxruntime1`, `onnxruntime1-cuda`, ...
        """
        self._provider = 'onnxruntime'
        self.onnx_node = onnx_node
        self.desc = desc
        self.runtime = runtime
        self._schema = _schemas.get(onnx_node.op_type, None)
        if desc is not None:
            if 'atts' in desc:
                for a, b in desc['atts'].items():
                    if not isinstance(b, dict) or 'value' not in b:
                        raise ValueError(  # pragma: no cover
                            "Unexpected value {}.".format(b))
                    options[a] = b['value']

        self.options = options
        self.dtype = dtype
        self._init(variables)

        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            InvalidArgument as OrtInvalidArgument)
        self.OrtInvalidArgument = OrtInvalidArgument
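
    # Illustrative note (not part of the original module): ``desc`` is expected
    # to carry node attributes as nested dictionaries, for instance
    # ``desc = {'atts': {'axis': {'value': 1}}}``, which the loop above turns
    # into ``options['axis'] = 1`` before ``_init`` builds the ONNX graph.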

    def _name_mapping(self, inputs):
        mapping = {}
        new_inputs = []
        for name in inputs:
            if name in mapping:
                i = 0
                new_name = "{}_{}".format(name, i)
                while new_name in mapping:
                    i += 1  # pragma: no cover
                    new_name = "{}_{}".format(name, i)  # pragma: no cover
                mapping[new_name] = name
                new_inputs.append(new_name)
            else:
                new_inputs.append(name)
                mapping[name] = name
        return mapping, new_inputs
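
    # Illustrative example (not in the original module): duplicated input names
    # are disambiguated with a numeric suffix, e.g.
    #   self._name_mapping(['X', 'X', 'Y'])
    #   -> ({'X': 'X', 'X_0': 'X', 'Y': 'Y'}, ['X', 'X_0', 'Y'])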

    def _guess_proto_type(self, dtype):
        return guess_proto_dtype(dtype)
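
    # Illustrative example (not in the original module): ``guess_proto_dtype``
    # converts a numpy dtype into an ONNX TensorProto element type, e.g.
    # numpy.float32 -> TensorProto.FLOAT, numpy.int64 -> TensorProto.INT64.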

    def _init(self, variables=None):
        """
        Initializes the node.

        :param variables: registered variables created by previous operators

        The current implementation for operator *Scan*
        only works for matrices.
        """
        custom_nodes = self.options.get('nodes', None)
        if (custom_nodes is not None and
                self.onnx_node.op_type in custom_nodes):
            self.alg_class = custom_nodes[self.onnx_node.op_type]
        else:
            try:
                import mlprodict.onnx_conv.onnx_ops as alg0
                self.alg_class = getattr(alg0, 'Onnx' + self.onnx_node.op_type)
            except AttributeError:
                import skl2onnx.algebra.custom_ops as alg2  # delayed
                try:
                    self.alg_class = getattr(
                        alg2, 'Onnx' + self.onnx_node.op_type)
                except AttributeError:
                    import skl2onnx.algebra.onnx_ops as alg  # delayed
                    self.alg_class = getattr(
                        alg, 'Onnx' + self.onnx_node.op_type)

        inputs = list(self.onnx_node.input)
        self.mapping, self.inputs = self._name_mapping(inputs)
        self.outputs = list(self.onnx_node.output)

        options = self.options.copy()
        options.pop('nodes', None)
        target_opset = options.pop('target_opset', None)
        domain = options.pop('domain', None)
        disable_optimisation = options.pop('disable_optimisation', False)
        session_options = options.pop('session_options', False)
        ir_version = options.pop('ir_version', None)

        if domain == '' and target_opset < 9:
            # target_opset should be >= 9 for the main domain.
            # We assume it was the case when the graph was created.
            pass

        if self.onnx_node.op_type == 'ZipMap':
            from skl2onnx.common.data_types import (  # delayed
                DictionaryType, FloatTensorType, Int64TensorType, StringTensorType)
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            otype = (Int64TensorType if 'classlabels_int64s' in options
                     else StringTensorType)
            outvar = [(name, DictionaryType(otype([1]), FloatTensorType([1])))]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ConstantOfShape':
            for k in options:
                v = options[k]
                if isinstance(v, numpy.ndarray):
                    options[k] = make_tensor(
                        k, self._guess_proto_type(v.dtype),
                        v.shape, v.tolist())

            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            try:
                self.onnx_ = self.inst_.to_onnx(inputs, target_opset=target_opset,
                                                domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_))
            except AttributeError as e:  # pragma: no cover
                # older version of skl2onnx
                self.onnx_ = self.inst_.to_onnx(inputs)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from e
            forced = False
        elif self.onnx_node.op_type == 'Scan':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype)
            inputs = [(name, cl.__class__([None, None]))
                      for (name, cl) in inputs]
            outputs = [(name, cl.__class__([None, None]))
                       for (name, cl) in outputs]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    "Probable issue as one dimension is null.\n--\n{}".format(
                        self.onnx_))
            forced = True
        else:
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, domain=domain,
                                        **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype,
                schema=self.alg_class.expected_inputs)

            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}\n---\n{}".format(
                            self.onnx_, inputs))
                forced = False
            except (RuntimeError, ValueError, InferenceError) as eo:
                # Let's try again by forcing output types.
                forced = True
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype, schema=self.alg_class.expected_outputs,
                    schema_inputs=self.alg_class.expected_inputs)
                try:
                    self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                                    target_opset=target_opset,
                                                    domain=domain)
                except NotImplementedError as e:  # pragma: no cover
                    raise NotImplementedError(
                        "Unable to instantiate node {} inputs={} "
                        "self.inputs={} outputs={} variables={} "
                        "dtype={} e={} eo={}".format(
                            self.alg_class, inputs, self.inputs,
                            outputs, variables, self.dtype, e, eo)) from e
                if "dim_value: 0" in str(self.onnx_):
                    # chain from the original failure 'eo'; 'e' from the inner
                    # except block is out of scope here.
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from eo

        if len(self.onnx_.graph.output) != len(self.outputs):  # pragma: no cover
            # Something is wrong, fall back to the default plan.
            forced = True
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype, schema=self.alg_class.expected_outputs)
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    "Probable issue as one dimension is null.\n--\n{}".format(
                        self.onnx_))
        else:
            lo = list(self.onnx_.graph.output)
            outputs = proto2vars(lo)

        from onnxruntime import (  # pylint: disable=E0611
            SessionOptions, RunOptions, GraphOptimizationLevel)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
            NotImplemented as OrtNotImplemented)

        sess_options = session_options or SessionOptions()
        self.run_options = RunOptions()

        if session_options is None:
            try:
                sess_options.session_log_severity_level = 3
                # sess_options.sessions_log_verbosity_level = 0
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            try:
                self.run_options.run_log_severity_level = 3
                # self.run_options.run_log_verbosity_level = 0
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            if disable_optimisation:
                sess_options.graph_optimization_level = (  # pragma: no cover
                    GraphOptimizationLevel.ORT_DISABLE_ALL)
        elif disable_optimisation:
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined "
                "at the same time.")

        if ir_version is not None:
            self.onnx_.ir_version = ir_version
        try:
            self.sess_ = InferenceSession(
                self.onnx_.SerializeToString(), sess_options=sess_options,
                runtime=self.runtime)
        except (RuntimeError, OrtNotImplemented, OrtInvalidGraph, OrtFail) as e:
            raise RuntimeError(
                "Unable to load node '{}' (output type was {}) inputs={} "
                "self.inputs={} self.onnx_node.input={} "
                "variables={} mapping={} "
                "expected_inputs={}\n{}".format(
                    self.onnx_node.op_type,
                    "guessed" if forced else "inferred",
                    inputs, self.inputs, self.onnx_node.input,
                    variables, self.mapping,
                    self.alg_class.expected_inputs,
                    self.onnx_)) from e
        self.typed_outputs_ = outputs

    def run(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        inputs = {name: val for name, val in zip(self.inputs, args)}

        try:
            res = self.sess_.run(None, inputs, self.run_options)
        except (RuntimeError, self.OrtInvalidArgument) as e:  # pragma: no cover
            dtypes = {k: v.dtype for k, v in inputs.items()}
            shapes = {k: v.shape for k, v in inputs.items()}
            exp = [_.name for _ in self.sess_.get_inputs()]
            exp_types = [_.type for _ in self.sess_.get_inputs()]
            raise RuntimeError(
                "Predictions failed. List of inputs: {}, class={}"
                "\ndtypes={}\nshapes={}\nexpected={}\nexpected_types={}\n"
                "exception={}\n--ONNX--\n{}".format(
                    list(sorted(inputs)), self.alg_class, dtypes,
                    shapes, exp, exp_types, e, self.onnx_)) from e
        return tuple(res)

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False
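
A minimal usage sketch (not part of the measured module), assuming the package
layout shown in the header; the Add node, the tensor names and the opset below
are made up for illustration. The class wraps a single ONNX node into a
one-node graph and runs it with onnxruntime:

import numpy
from onnx import helper
from mlprodict.onnxrt.ops_onnxruntime._op import OpRunOnnxRuntime

# Hypothetical single node; 'X', 'Y', 'Z' are illustrative names.
node = helper.make_node('Add', inputs=['X', 'Y'], outputs=['Z'])

# dtype drives the input types guessed when no variables are registered yet;
# target_opset travels through **options and is popped inside _init.
op = OpRunOnnxRuntime(node, desc=None, variables=None,
                      dtype=numpy.float32, target_opset=15)

x = numpy.array([[1., 2.], [3., 4.]], dtype=numpy.float32)
y = numpy.array([[10., 20.], [30., 40.]], dtype=numpy.float32)
print(op.run(x, y))   # expected: a tuple holding one array equal to x + y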