Coverage for onnxcustom/utils/onnxruntime_helper.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

58 statements  

1""" 

2@file 

3@brief Onnxruntime helper. 

4""" 

5from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

6 OrtDevice as C_OrtDevice, OrtValue as C_OrtValue) 

7 

8 

def provider_to_device(provider_name):
    """
    Converts an onnxruntime execution provider name into a device name.

    :param provider_name: provider name, ``'CPUExecutionProvider'`` or
        ``'CUDAExecutionProvider'``
    :return: device name, ``'cpu'`` or ``'cuda'``
    :raises ValueError: for any other provider name

    .. runpython::
        :showcode:

        from onnxcustom.utils.onnxruntime_helper import provider_to_device
        print(provider_to_device('CPUExecutionProvider'))
    """
    if provider_name == 'CPUExecutionProvider':
        return 'cpu'
    if provider_name == 'CUDAExecutionProvider':
        return 'cuda'
    # The value is wrapped in a 1-tuple: with a bare ``% provider_name``,
    # a tuple argument would be unpacked by ``%`` and fail with TypeError
    # instead of producing the intended ValueError.
    raise ValueError(
        "Unexpected value for provider_name=%r." % (provider_name,))

28 

29 

def get_ort_device_type(device):
    """
    Converts a device into an onnxruntime device type.

    String devices accept the same spellings as ``get_ort_device``:
    ``'cpu'``, ``'cuda'``, ``'gpu'`` and the indexed forms ``'cuda:<n>'``
    and ``'gpu:<n>'`` with a non-negative integer ``<n>`` (the index does
    not change the device type).

    :param device: string, or any object exposing a ``device_type()``
        method (such as :epkg:`C_OrtDevice`)
    :return: ``C_OrtDevice.cpu()`` or ``C_OrtDevice.cuda()``
    :raises ValueError: if *device* names an unsupported device
    :raises TypeError: if *device* is neither a string nor an object
        with a ``device_type`` attribute
    """
    if isinstance(device, str):
        if device == 'cpu':
            return C_OrtDevice.cpu()
        name, sep, index = device.partition(':')
        if name in ('cuda', 'gpu'):
            # Validate an explicit index the same way int() would parse it;
            # an unparsable or negative index is rejected below.
            try:
                idx = int(index) if sep else 0
            except ValueError:
                idx = -1
            if idx >= 0:
                return C_OrtDevice.cuda()
        raise ValueError(  # pragma: no cover
            'Unsupported device type: %r.' % (device,))
    if not hasattr(device, 'device_type'):
        raise TypeError('Unsupported device type: %r.' % type(device))
    device_type = device.device_type()
    # device_type() may return either a name or a numeric code.
    if device_type in ('cuda', 1):
        return C_OrtDevice.cuda()
    if device_type in ('cpu', 0):
        return C_OrtDevice.cpu()
    raise ValueError(  # pragma: no cover
        'Unsupported device type: %r.' % (device_type,))

53 

54 

def get_ort_device(device):
    """
    Converts a device into :epkg:`C_OrtDevice`.

    :param device: :epkg:`C_OrtDevice` (returned unchanged) or a string:
        ``'cpu'``, ``'cuda'``, ``'gpu'``, ``'cuda:<n>'`` or ``'gpu:<n>'``
        where ``<n>`` is a non-negative integer device index
    :return: :epkg:`C_OrtDevice`
    :raises ValueError: if a string cannot be interpreted as a device
    :raises TypeError: for any other type

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    if isinstance(device, str):
        if device == 'cpu':
            return C_OrtDevice(
                C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
        # 'cuda' and 'gpu' are aliases; an optional ':<n>' selects the index.
        name, sep, index = device.partition(':')
        if name not in ('cuda', 'gpu'):
            raise ValueError(
                "Unable to interpret string %r as a device." % device)
        try:
            idx = int(index) if sep else 0
        except ValueError as exc:
            # Report the whole device string rather than int()'s message.
            raise ValueError(
                "Unable to interpret string %r as a device."
                % device) from exc
        if idx < 0:
            raise ValueError(
                "Unable to interpret string %r as a device." % device)
        return C_OrtDevice(
            C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)
    if isinstance(device, C_OrtDevice):
        return device
    raise TypeError(  # pragma: no cover
        "Unable to interpret type %r, (%r) as a device." % (
            type(device), device))

93 

94 

def ort_device_to_string(device):
    """
    Builds the textual name of a device, the reverse operation of
    @see fn get_ort_device.

    :param device: see :epkg:`C_OrtDevice`
    :return: ``'cpu'`` or ``'cuda'``, followed by ``':<id>'`` when the
        device id is not 0
    :raises TypeError: if *device* is not a :epkg:`C_OrtDevice`
    """
    if not isinstance(device, C_OrtDevice):
        raise TypeError(
            "device must be of type C_OrtDevice not %r." % type(device))
    kind = device.device_type()
    if kind not in (C_OrtDevice.cpu(), C_OrtDevice.cuda()):
        raise NotImplementedError(  # pragma: no cover
            "Unable to guess device for %r and type=%r." % (device, kind))
    label = 'cpu' if kind == C_OrtDevice.cpu() else 'cuda'
    device_index = device.device_id()
    # Index 0 is implicit, matching get_ort_device('cpu') / ('cuda').
    return label if device_index == 0 else "%s:%d" % (label, device_index)

118 

119 

def numpy_to_ort_value(arr, device=None):
    """
    Converts a numpy array to :epkg:`C_OrtValue`.

    :param arr: numpy array
    :param device: :epkg:`C_OrtDevice`, a device string understood by
        ``get_ort_device`` (``'cpu'``, ``'cuda'``, ``'cuda:1'``, ...),
        or None for cpu
    :return: :epkg:`C_OrtValue`
    """
    if device is None:
        device = 'cpu'
    if isinstance(device, str):
        # Strings are resolved the same way device_to_providers does.
        device = get_ort_device(device)
    return C_OrtValue.ortvalue_from_numpy(arr, device)

131 

132 

def device_to_providers(device):
    """
    Gives the execution providers matching a device.

    :param device: :epkg:`C_OrtDevice` or a device string accepted by
        ``get_ort_device``
    :return: providers, a one-element list
    :raises ValueError: if the device type is neither cpu nor cuda
    """
    ort_device = get_ort_device(device) if isinstance(device, str) else device
    kind = ort_device.device_type()
    if kind == ort_device.cpu():
        return ['CPUExecutionProvider']
    if kind == ort_device.cuda():
        return ['CUDAExecutionProvider']
    raise ValueError(  # pragma: no cover
        "Unexpected device %r." % ort_device)