Coverage for mlprodict/tools/onnx_inference_ort_helper.py: 35%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

23 statements  

# pylint: disable=C0302
"""
@file
@brief Helpers for :epkg:`onnxruntime`.
"""

6 

7 

def _parse_ort_device_string(device):
    """
    Splits a device string into a device family and a device index.

    :param device: string such as ``'cpu'``, ``'gpu'``, ``'cuda'``,
        ``'gpu:<index>'`` or ``'cuda:<index>'``
    :return: tuple ``(family, index)`` where *family* is ``'cpu'``
        or ``'cuda'`` and *index* is an integer
    :raises ValueError: if the string is not one of the supported forms
        or if the index is not an integer
    """
    if device == 'cpu':
        return 'cpu', 0
    if device in {'gpu', 'cuda'}:
        return 'cuda', 0
    prefix, sep, suffix = device.partition(':')
    if sep and prefix in {'gpu', 'cuda'}:
        try:
            return 'cuda', int(suffix)
        except ValueError as e:
            # int() only reports the suffix, the full device string
            # is more useful to the caller.
            raise ValueError(
                "Unable to interpret string %r as a device." % device) from e
    raise ValueError(  # pragma: no cover
        "Unable to interpret string %r as a device." % device)


def get_ort_device(device):
    """
    Converts device into :epkg:`C_OrtDevice`.

    :param device: a :epkg:`C_OrtDevice` (returned unchanged) or a string
        among ``'cpu'``, ``'gpu'``, ``'cuda'``, ``'gpu:<index>'``,
        ``'cuda:<index>'``
    :return: :epkg:`C_OrtDevice`
    :raises ValueError: if *device* is a string which cannot be interpreted
    :raises TypeError: if *device* is neither a string
        nor a :epkg:`C_OrtDevice`

    Example:

    ::

        get_ort_device('cpu')
        get_ort_device('gpu')
        get_ort_device('cuda')
        get_ort_device('cuda:0')
    """
    # Strings are validated before the delayed import so that an invalid
    # name fails with ValueError without loading onnxruntime.
    parsed = (_parse_ort_device_string(device)
              if isinstance(device, str) else None)

    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611,W0611
        OrtDevice as C_OrtDevice)  # delayed

    if parsed is None:
        if isinstance(device, C_OrtDevice):
            return device
        raise TypeError(  # pragma: no cover
            "Unable to interpret type %r, (%r) as a device." % (
                type(device), device))

    family, index = parsed
    device_type = C_OrtDevice.cpu() if family == 'cpu' else C_OrtDevice.cuda()
    return C_OrtDevice(device_type, C_OrtDevice.default_memory(), index)

48 

49 

def device_to_providers(device):
    """
    Returns the corresponding providers for a specific device.

    :param device: :epkg:`C_OrtDevice` or a string, a string is
        first converted with @see fn get_ort_device
    :return: providers, a list of execution provider names
    :raises ValueError: if the device type is neither the value returned
        by ``device.cpu()`` nor the one returned by ``device.cuda()``
    """
    if isinstance(device, str):
        device = get_ort_device(device)
    device_type = device.device_type()
    if device_type == device.cpu():
        return ['CPUExecutionProvider']
    if device_type == device.cuda():
        # The CPU provider is listed after the CUDA provider.
        return ['CUDAExecutionProvider', 'CPUExecutionProvider']
    raise ValueError(  # pragma: no cover
        "Unexpected device %r." % (device, ))