Coverage for mlprodict/onnxrt/ops_whole/session.py: 96%


70 statements  

# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_whole*.
"""
import json
from io import BytesIO
import numpy
import onnx


class OnnxWholeSession:
    """
    Runs the prediction for a single :epkg:`ONNX` model
    and lets the runtime handle the graph logic as well.

    :param onnx_data: :epkg:`ONNX` model or data
    :param runtime: runtime to be used, mostly :epkg:`onnxruntime`
    :param runtime_options: runtime options
    :param device: device, a string `cpu`, `cuda`, `cuda:0`...

    .. versionchanged:: 0.8
        Parameter *device* was added.
    """

    def __init__(self, onnx_data, runtime, runtime_options=None, device=None):
        if runtime not in ('onnxruntime1', 'onnxruntime1-cuda'):
            raise NotImplementedError(  # pragma: no cover
                "runtime '{}' is not implemented.".format(runtime))

        from onnxruntime import (  # delayed
            InferenceSession, SessionOptions, RunOptions,
            GraphOptimizationLevel)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
            InvalidArgument as OrtInvalidArgument,
            NotImplemented as OrtNotImplemented,
            RuntimeException as OrtRuntimeException)

        if hasattr(onnx_data, 'SerializeToString'):
            onnx_data = onnx_data.SerializeToString()
        if isinstance(runtime_options, SessionOptions):
            # a user-provided SessionOptions instance is used as is
            sess_options = runtime_options
            session_options = None
            runtime_options = None
        else:
            session_options = (
                None if runtime_options is None
                else runtime_options.get('session_options', None))
            sess_options = session_options or SessionOptions()
        self.runtime = runtime
        self.run_options = RunOptions()
        self.run_options.log_severity_level = 3
        self.run_options.log_verbosity_level = 1

        if session_options is None:
            if runtime_options is not None:
                if runtime_options.get('disable_optimisation', False):
                    # disables every graph optimisation
                    sess_options.graph_optimization_level = (  # pragma: no cover
                        GraphOptimizationLevel.ORT_DISABLE_ALL)
                if runtime_options.get('enable_profiling', False):
                    # profiling is opt-in
                    sess_options.enable_profiling = True
                if runtime_options.get('log_severity_level', 2) != 2:
                    v = runtime_options.get('log_severity_level', 2)
                    sess_options.log_severity_level = v
                    self.run_options.log_severity_level = v
        elif runtime_options is not None and 'enable_profiling' in runtime_options:
            raise RuntimeError(  # pragma: no cover
                "session_options and enable_profiling cannot be defined at the "
                "same time.")
        elif runtime_options is not None and 'disable_optimisation' in runtime_options:
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined at the "
                "same time.")
        elif runtime_options is not None and 'log_severity_level' in runtime_options:
            raise RuntimeError(  # pragma: no cover
                "session_options and log_severity_level cannot be defined at the "
                "same time.")
        providers = ['CPUExecutionProvider']
        if runtime == 'onnxruntime1-cuda':
            providers = ['CUDAExecutionProvider'] + providers
        try:
            self.sess = InferenceSession(onnx_data, sess_options=sess_options,
                                         device=device, providers=providers)
        except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
                OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e:
            from ...tools.asv_options_helper import display_onnx
            raise RuntimeError(
                "Unable to create InferenceSession due to '{}'\n{}.".format(
                    e, display_onnx(onnx.load(BytesIO(onnx_data))))) from e
        self.output_names = [_.name for _ in self.sess.get_outputs()]
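
    # Construction sketch (hypothetical byte string `model_bytes`): runtime
    # options are usually passed as a plain dictionary,
    #
    #   OnnxWholeSession(model_bytes, 'onnxruntime1',
    #                    runtime_options={'enable_profiling': True,
    #                                     'log_severity_level': 3})
    #
    # or directly as an onnxruntime SessionOptions instance, in which case
    # the dictionary keys above must not be used at the same time.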

    def run(self, inputs):
        """
        Computes the predictions.

        @param      inputs      dictionary *{variable, value}*
        @return                 list of outputs
        """
        v = next(iter(inputs.values()))
        if isinstance(v, (numpy.ndarray, dict)):
            # numpy inputs are fed directly to the C session
            try:
                return self.sess._sess.run(
                    self.output_names, inputs, self.run_options)
            except ValueError as e:
                raise ValueError(
                    "Issue running inference inputs=%r, expected inputs=%r."
                    "" % (
                        list(sorted(inputs)),
                        [i.name for i in self.sess.get_inputs()])) from e
        # otherwise inputs are assumed to hold OrtValue instances
        try:
            return self.sess._sess.run_with_ort_values(
                inputs, self.output_names, self.run_options)
        except RuntimeError:
            # fallback: some versions expect the C object behind each OrtValue
            return self.sess._sess.run_with_ort_values(
                {k: v._get_c_value() for k, v in inputs.items()},
                self.output_names, self.run_options)
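
    # Usage sketch (hypothetical input name 'X'): for a model with a single
    # float input, numpy arrays are fed directly,
    #
    #   outputs = whole.run(
    #       {'X': numpy.array([[0., 1.]], dtype=numpy.float32)})
    #
    # and `outputs` is a list ordered like `self.output_names`.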

    @staticmethod
    def process_profiling(js):
        """
        Flattens the JSON returned by the onnxruntime profiler.

        :param js: json
        :return: list of dictionaries
        """
        rows = []
        for row in js:
            if 'args' in row and isinstance(row['args'], dict):
                for k, v in row['args'].items():
                    row['args_%s' % k] = v
                del row['args']
            rows.append(row)
        return rows
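
    # Flattening sketch (illustrative event): a profiling row such as
    #   {'cat': 'Node', 'dur': 10, 'args': {'op_name': 'Add'}}
    # becomes
    #   {'cat': 'Node', 'dur': 10, 'args_op_name': 'Add'}
    # so that every record is a flat dictionary.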

    def get_profiling(self):
        """
        Returns the profiling information.
        """
        prof = self.sess.end_profiling()
        with open(prof, 'r') as f:
            content = f.read()
        js = json.loads(content)
        return OnnxWholeSession.process_profiling(js)
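
A minimal end-to-end sketch of the class above, not part of the original file: it builds a tiny Add model with onnx.helper (the names `X`, `Y` and the opset are illustrative), wraps it in OnnxWholeSession with profiling enabled, runs one prediction and flattens the profiling rows. It assumes onnxruntime is installed and compatible with this wrapper.

import numpy
from onnx import TensorProto, helper

# tiny graph computing Y = X + X
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
graph = helper.make_graph(
    [helper.make_node('Add', ['X', 'X'], ['Y'])], 'tiny', [X], [Y])
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid('', 14)])

# the constructor serializes the ModelProto itself
whole = OnnxWholeSession(model, 'onnxruntime1',
                         runtime_options={'enable_profiling': True})
print(whole.run({'X': numpy.array([[1., 2.]], dtype=numpy.float32)}))
# -> [array([[2., 4.]], dtype=float32)]
rows = whole.get_profiling()  # list of flat dictionaries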