Coverage for mlprodict/onnxrt/ops_cpu/__init__.py: 87%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

47 statements  

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_cpu*. 

5""" 

6import textwrap 

7from ..excs import MissingOperatorError 

8from ._op import OpRunCustom 

9from ._op_list import __dict__ as d_op_list 

10 

11 

# Registry of operator classes added through register_operator, keyed by
# the registered name. load_op consults it before the built-in
# implementations coming from _op_list.
_additional_ops = {}


def register_operator(cls, name=None, overwrite=True):
    """
    Registers a new runtime operator.

    @param      cls         class
    @param      name        by default ``cls.__name__``,
                            or *name* if defined
    @param      overwrite   if True, an operator already registered under
                            the same name is replaced by *cls*; if False,
                            an exception is raised instead
    @raise      RuntimeError if *name* is already registered and
                            *overwrite* is False
    """
    if name is None:
        name = cls.__name__
    if name in _additional_ops and not overwrite:
        raise RuntimeError(  # pragma: no cover
            "Unable to overwrite existing operator '{}': {} "
            "by {}".format(name, _additional_ops[name], cls))
    # New name, or existing name with overwrite=True: the latest
    # registration wins, as documented for *overwrite*.
    _additional_ops[name] = cls

32 

33 

def _available_operators():
    """
    Returns the sorted names of the built-in operators found in
    ``_op_list``, skipping every name containing ``_`` (opset-versioned
    implementations and dunder attributes) and the non-operator names
    ``'cl'``, ``'clo'``, ``'name'``.
    """
    return [op_name for op_name in sorted(d_op_list)
            if "_" not in op_name and op_name not in {'cl', 'clo', 'name'}]


def load_op(onnx_node, desc=None, options=None, runtime=None):
    """
    Gets the operator related to the *onnx* node.

    :param onnx_node: :epkg:`onnx` node
    :param desc: internal representation
    :param options: runtime options; ``options['target_opset']`` (an
        integer) selects the opset version of the implementation, the
        whole dictionary is forwarded to the operator constructor
    :param runtime: runtime, forwarded to the operator constructor
    :return: instance of the runtime operator class
    :raises ValueError: if *desc* is None
    :raises TypeError: if ``target_opset`` is not an integer
    :raises MissingOperatorError: if no implementation exists
    :raises RuntimeError: if the implementation requires a higher opset
        than the one that can be provided
    """
    from ... import __max_supported_opset__
    if desc is None:
        raise ValueError("desc should not be None.")  # pragma no cover
    # Normalise once so the membership test and copy() below cannot fail
    # on None; the caller's dictionary is never modified.
    if options is None:
        options = {}
    name = onnx_node.op_type
    opset = options.get('target_opset', None)
    current_opset = __max_supported_opset__
    chosen_opset = current_opset
    if opset == current_opset:
        opset = None
    if opset is not None:
        if not isinstance(opset, int):
            raise TypeError(  # pragma no cover
                "opset must be an integer not {}".format(type(opset)))
        name_opset = name + "_" + str(opset)
        # Picks the most recent built-in versioned implementation
        # ``<name>_<op>`` with ``op <= opset``.
        for op in range(opset, 0, -1):
            nop = name + "_" + str(op)
            if nop in d_op_list:
                name_opset = nop
                chosen_opset = op
                break
    else:
        name_opset = name

    # Lookup order: custom versioned, custom unversioned,
    # built-in versioned, built-in unversioned.
    if name_opset in _additional_ops:
        cl = _additional_ops[name_opset]
    elif name in _additional_ops:
        cl = _additional_ops[name]
    elif name_opset in d_op_list:
        cl = d_op_list[name_opset]
    elif name in d_op_list:
        cl = d_op_list[name]
    else:
        raise MissingOperatorError(  # pragma no cover
            "Operator '{}' from domain '{}' has no runtime yet. "
            "Available list:\n"
            "{} - {}".format(
                name, onnx_node.domain,
                "\n".join(sorted(_additional_ops)),
                "\n".join(textwrap.wrap(" ".join(_available_operators())))))

    if hasattr(cl, 'version_higher_than'):
        opv = min(current_opset, chosen_opset)
        if cl.version_higher_than > opv:
            # The chosen implementation does not support
            # the opset version, we need to downgrade it.
            if options.get('target_opset') is not None:  # pragma: no cover
                raise RuntimeError(
                    "Supported version {} > {} (opset={}) required version, "
                    "unable to find an implementation version {} found "
                    "'{}'\n--ONNX--\n{}\n--AVAILABLE--\n{}".format(
                        cl.version_higher_than, opv, opset,
                        options['target_opset'], cl.__name__, onnx_node,
                        "\n".join(_available_operators())))
            options = options.copy()
            options['target_opset'] = current_opset
            # runtime must be forwarded, otherwise the retry loses it.
            return load_op(onnx_node, desc=desc, options=options,
                           runtime=runtime)

    return cl(onnx_node, desc=desc, runtime=runtime, **options)