Coverage for onnxcustom/utils/onnxruntime_helper.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Onnxruntime helper.
4"""
5from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
6 OrtDevice as C_OrtDevice, OrtValue as C_OrtValue)
def provider_to_device(provider_name):
    """
    Converts an onnxruntime execution provider name into a device name.

    :param provider_name: provider name, ``'CPUExecutionProvider'`` or
        ``'CUDAExecutionProvider'``
    :return: device name, ``'cpu'`` or ``'cuda'``
    :raises ValueError: for any other provider name

    .. runpython::
        :showcode:

        from onnxcustom.utils.onnxruntime_helper import provider_to_device
        print(provider_to_device('CPUExecutionProvider'))
    """
    # A sequence of pairs (not a dict) so that unhashable inputs still end
    # up in the ValueError below instead of failing on hashing.
    known_providers = (
        ('CPUExecutionProvider', 'cpu'),
        ('CUDAExecutionProvider', 'cuda'),
    )
    for name, device_name in known_providers:
        if provider_name == name:
            return device_name
    raise ValueError(
        "Unexpected value for provider_name=%r." % provider_name)
def get_ort_device_type(device):
    """
    Converts device into device type.

    :param device: ``'cpu'``, ``'cuda'``, or an object exposing a
        ``device_type()`` method returning ``'cpu'``/``0`` or
        ``'cuda'``/``1``
    :return: device type, ``C_OrtDevice.cpu()`` or ``C_OrtDevice.cuda()``
    :raises TypeError: if *device* is not a string and has no
        ``device_type`` attribute
    :raises ValueError: if the string or the value returned by
        ``device_type()`` is not recognized
    """
    if not isinstance(device, str):
        if not hasattr(device, 'device_type'):
            raise TypeError('Unsupported device type: %r.' % type(device))
        # Accept both the symbolic name and the numeric onnxruntime code.
        kind = device.device_type()
        if kind in ('cuda', 1):
            return C_OrtDevice.cuda()
        if kind in ('cpu', 0):
            return C_OrtDevice.cpu()
        raise ValueError(  # pragma: no cover
            'Unsupported device type: %r.' % kind)

    if device == 'cuda':
        return C_OrtDevice.cuda()
    if device == 'cpu':
        return C_OrtDevice.cpu()
    raise ValueError(  # pragma: no cover
        'Unsupported device type: %r.' % device)
def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: a :epkg:`C_OrtDevice` (returned unchanged) or a string
        ``'cpu'``, ``'gpu'`` or ``'cuda'``, optionally followed by
        ``':<index>'``; ``'gpu'`` is an alias for ``'cuda'`` and the
        index defaults to 0
    :return: :epkg:`C_OrtDevice`
    :raises ValueError: if the string cannot be interpreted as a device
    :raises TypeError: if *device* is neither a string nor a
        :epkg:`C_OrtDevice`

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('cpu:0')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
        get_ort_device('cuda:1')
    """
    if isinstance(device, C_OrtDevice):
        return device
    if not isinstance(device, str):
        raise TypeError(  # pragma: no cover
            "Unable to interpret type %r, (%r) as a device." % (
                type(device), device))

    # 'cuda:1' -> ('cuda', ':', '1'); 'cuda' -> ('cuda', '', '').
    name, sep, suffix = device.partition(':')
    if name == 'cpu':
        device_type = C_OrtDevice.cpu()
    elif name in ('gpu', 'cuda'):
        device_type = C_OrtDevice.cuda()
    else:
        raise ValueError(
            "Unable to interpret string %r as a device." % device)

    if not sep:
        idx = 0
    else:
        try:
            idx = int(suffix)
        except ValueError as exc:
            raise ValueError(
                "Unable to interpret string %r as a device, index %r "
                "is not an integer." % (device, suffix)) from exc
    return C_OrtDevice(device_type, C_OrtDevice.default_memory(), idx)
def ort_device_to_string(device):
    """
    Returns a string representing the device.
    Opposite of function @see fn get_ort_device.

    :param device: see :epkg:`C_OrtDevice`
    :return: ``'cpu'`` or ``'cuda'``, followed by ``':<device_id>'``
        only when the device id is not 0
    :raises TypeError: if *device* is not a :epkg:`C_OrtDevice`
    :raises NotImplementedError: for device types other than cpu and cuda
    """
    if not isinstance(device, C_OrtDevice):
        raise TypeError(
            "device must be of type C_OrtDevice not %r." % type(device))
    kind = device.device_type()
    labels = ((C_OrtDevice.cpu(), 'cpu'), (C_OrtDevice.cuda(), 'cuda'))
    label = next((text for code, text in labels if kind == code), None)
    if label is None:
        raise NotImplementedError(  # pragma: no cover
            "Unable to guess device for %r and type=%r." % (device, kind))
    index = device.device_id()
    # Device 0 is the default one, its index is left implicit.
    return label if index == 0 else "%s:%d" % (label, index)
def numpy_to_ort_value(arr, device=None):
    """
    Converts a numpy array to :epkg:`C_OrtValue`.

    :param arr: numpy array
    :param device: :epkg:`C_OrtDevice`, a device string understood by
        @see fn get_ort_device (``'cpu'``, ``'cuda'``, ``'cuda:1'``, ...),
        or None for cpu
    :return: :epkg:`C_OrtValue`
    """
    # get_ort_device returns C_OrtDevice instances unchanged, so existing
    # callers see no difference; device strings are converted as well.
    device = get_ort_device('cpu' if device is None else device)
    return C_OrtValue.ortvalue_from_numpy(arr, device)
def device_to_providers(device):
    """
    Returns the corresponding providers for a specific device.

    :param device: :epkg:`C_OrtDevice` or a device string understood by
        @see fn get_ort_device
    :return: providers, ``['CPUExecutionProvider']`` or
        ``['CUDAExecutionProvider']``
    :raises ValueError: if the device type is neither cpu nor cuda
    """
    if isinstance(device, str):
        device = get_ort_device(device)
    # Compare with the class-level constants, as the other helpers of this
    # module do, instead of reaching them through the instance; this also
    # queries the device type only once.
    device_type = device.device_type()
    if device_type == C_OrtDevice.cpu():
        return ['CPUExecutionProvider']
    if device_type == C_OrtDevice.cuda():
        return ['CUDAExecutionProvider']
    raise ValueError(  # pragma: no cover
        "Unexpected device %r." % device)