Coverage for mlprodict/tools/onnx_inference_ort_helper.py: 35%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=C0302
2"""
3@file
4@brief Helpers for :epkg:`onnxruntime`.
5"""
def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: a :epkg:`C_OrtDevice` (returned unchanged) or a string,
        one of ``'cpu'``, ``'gpu'``, ``'cuda'``, ``'gpu:<index>'``,
        ``'cuda:<index>'`` where ``<index>`` is a non-negative integer
    :return: :epkg:`C_OrtDevice`
    :raises ValueError: if the string cannot be interpreted as a device
        or the device index is not a non-negative integer
    :raises TypeError: if *device* is neither a string nor a
        :epkg:`C_OrtDevice`

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,W0611
        OrtDevice as C_OrtDevice)  # delayed
    if isinstance(device, C_OrtDevice):
        return device
    if not isinstance(device, str):
        raise TypeError(  # pragma: no cover
            "Unable to interpret type %r, (%r) as a device." % (
                type(device), device))

    if device == 'cpu':
        return C_OrtDevice(
            C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    if device in ('gpu', 'cuda'):
        # No explicit index: default to the first CUDA device.
        return C_OrtDevice(
            C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)

    # 'gpu:<index>' and 'cuda:<index>' are synonyms for a CUDA device.
    for prefix in ('gpu:', 'cuda:'):
        if device.startswith(prefix):
            try:
                idx = int(device[len(prefix):])
            except ValueError as e:
                # Re-raise with a message naming the offending device string
                # instead of the bare int() parsing error.
                raise ValueError(
                    "Unable to interpret string %r as a device." % device
                ) from e
            if idx < 0:
                raise ValueError(
                    "Device index must be non-negative in %r." % device)
            return C_OrtDevice(
                C_OrtDevice.cuda(), C_OrtDevice.default_memory(), idx)

    raise ValueError(  # pragma: no cover
        "Unable to interpret string %r as a device." % device)
def device_to_providers(device):
    """
    Returns the corresponding providers for a specific device.

    :param device: :epkg:`C_OrtDevice`, or any value accepted by
        :func:`get_ort_device` (for example ``'cpu'`` or ``'cuda:0'``)
    :return: list of execution provider names; for a CUDA device the
        CPU provider is listed after the CUDA one (presumably as a fallback)
    :raises TypeError: if *device* cannot be converted into a device
        (raised by :func:`get_ort_device`)
    :raises ValueError: if the device type is neither CPU nor CUDA
    """
    if not hasattr(device, 'device_type'):
        # Strings, and any unsupported type, go through get_ort_device.
        # That function raises an explicit TypeError for unsupported
        # types instead of failing with an AttributeError below.
        device = get_ort_device(device)
    device_type = device.device_type()
    if device_type == device.cpu():
        return ['CPUExecutionProvider']
    if device_type == device.cuda():
        return ['CUDAExecutionProvider', 'CPUExecutionProvider']
    raise ValueError(  # pragma: no cover
        "Unexpected device %r." % device)