Coverage for mlprodict/onnxrt/ops_cpu/op_if.py: 72%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

53 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from ...onnx_tools.onnx2py_helper import guess_dtype 

8from ..shape_object import ShapeObject 

9from ._op import OpRun 

10 

11 

class If(OpRun):
    """
    Runtime for the ONNX *If* operator.

    The node carries two subgraph attributes, ``then_branch`` and
    ``else_branch``; both must expose a ``run`` method. At execution time
    exactly one of them is run, selected by the condition input, and its
    outputs are returned in the order given by that branch's
    ``output_names``.
    """

    atts = {
        'then_branch': None,
        'else_branch': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=If.atts,
                       **options)
        if not hasattr(self.then_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'then_branch' must have a method 'run', "
                "type {}.".format(type(self.then_branch)))
        if not hasattr(self.else_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'else_branch' must have a method 'run', "
                "type {}.".format(type(self.else_branch)))

        # Prefer ``run_in_scan`` when the subgraph runtime provides it,
        # otherwise fall back to the plain ``run`` method.
        self._run_meth_then = (self.then_branch.run_in_scan
                               if hasattr(self.then_branch, 'run_in_scan')
                               else self.then_branch.run)
        self._run_meth_else = (self.else_branch.run_in_scan
                               if hasattr(self.else_branch, 'run_in_scan')
                               else self.else_branch.run)

    @staticmethod
    def _check_branch_inputs(branch, named_inputs):
        """
        Checks that *named_inputs* provides every input required by *branch*.

        :param branch: subgraph runtime exposing ``input_names``
        :param named_inputs: dictionary of available named inputs
        :raises RuntimeError: if *named_inputs* is empty while the branch
            needs inputs, or if one required input is missing
        """
        if len(branch.input_names) == 0:
            return
        if len(named_inputs) == 0:
            # Report the inputs of the branch actually being checked.
            raise RuntimeError(  # pragma: no cover
                "named_inputs is empty but the graph needs {}.".format(
                    branch.input_names))
        for k in branch.input_names:
            if k not in named_inputs:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find named input '{}' in\n{}.".format(
                        k, "\n".join(sorted(named_inputs))))

    def _run(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Runs the branch selected by *cond*.

        :param cond: condition; when it has a non-empty shape the then-branch
            is selected only if every element iterated by ``all`` is true,
            otherwise its truth value decides
        :param named_inputs: dictionary of named inputs forwarded to the
            selected branch (``None`` means no inputs)
        :return: tuple of the selected branch outputs, ordered by its
            ``output_names``
        :raises RuntimeError: if a branch needs an input missing from
            *named_inputs* (both branches are checked before running)
        """
        if named_inputs is None:
            named_inputs = {}
        self._check_branch_inputs(self.then_branch, named_inputs)
        self._check_branch_inputs(self.else_branch, named_inputs)

        if len(cond.shape) > 0:
            take_then = all(cond)
        else:
            take_then = bool(cond)

        if take_then:
            branch, run_meth = self.then_branch, self._run_meth_then
        else:
            branch, run_meth = self.else_branch, self._run_meth_else
        outputs = run_meth(named_inputs)
        return tuple(outputs[name] for name in branch.output_names)

    def _declared_output_dtype(self, res, name):
        """
        Returns the dtype declared for output *name* in the then-branch graph.

        :param res: inference results, only used in the error message
        :param name: output name
        :return: result of ``guess_dtype`` on the declared element type
        :raises ValueError: if *name* is not an output of the then-branch
        """
        out = {o.name: o for o in self.then_branch.obj.graph.output}
        if name not in out:
            raise ValueError(  # pragma: no cover
                "Unable to find name=%r in %r or %r." % (
                    name, sorted(res), sorted(out)))
        return guess_dtype(out[name].type.tensor_type.elem_type)

    def _pick_shape(self, res, name):
        """
        Returns the inferred shape for *name*, or a ``ShapeObject`` with an
        unknown shape and the declared dtype when inference produced none.
        """
        if name in res:
            return res[name]
        return ShapeObject(None, self._declared_output_dtype(res, name))

    def _infer_shapes(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Infers output shapes from the then-branch only.
        """
        res = self.then_branch._set_shape_inference_runtime()
        return tuple(self._pick_shape(res, name)
                     for name in self.then_branch.output_names)

    def _pick_type(self, res, name):
        """
        Returns the inferred type for *name*, or the dtype declared in the
        then-branch graph when inference produced none.
        """
        if name in res:
            return res[name]
        return self._declared_output_dtype(res, name)

    def _infer_types(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Infers output types from the then-branch only.
        """
        res = self.then_branch._set_type_inference_runtime()
        return tuple(self._pick_type(res, name)
                     for name in self.then_branch.output_names)