Coverage for mlprodict/onnxrt/ops_cpu/op_unsqueeze.py: 87%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

53 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ..shape_object import ShapeObject 

10from ._op import OpRunUnaryNum, OpRun 

11 

12 

class Unsqueeze_1(OpRunUnaryNum):
    """
    Runtime for operator *Unsqueeze* when *axes* is a node attribute.

    The constructor normalises attribute *axes* into a tuple (or
    ``None`` when it is empty); `_run` then inserts one dimension of
    size 1 for every listed axis.
    """

    # Expected attributes and their default values.
    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Forwards the arguments to ``OpRunUnaryNum.__init__`` and
        normalises ``self.axes``.

        :param onnx_node: forwarded to the base constructor
        :param desc: forwarded to the base constructor
        :param options: forwarded to the base constructor
        """
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Unsqueeze_1.atts,
                               **options)
        # An array or a list becomes a tuple, an empty container
        # becomes None so that `_run` can reject it explicitly.
        if isinstance(self.axes, numpy.ndarray):
            self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:
            self.axes = None
        elif isinstance(self.axes, list):
            self.axes = tuple(self.axes)

    def _run(self, data):  # pylint: disable=W0221
        """
        Inserts the dimensions listed in ``self.axes`` into *data*.

        :param data: array to expand
        :return: tuple holding the expanded array
        :raises RuntimeError: if ``self.axes`` is neither a tuple
            nor a list (typically ``None``)
        """
        if not isinstance(self.axes, (tuple, list)):
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_1).")
        # All axes are given to numpy in a single call so that every
        # value is interpreted as a position in the *output* shape.
        # Inserting them one by one in the listed order fails for
        # unsorted axes and misplaces several negative axes.
        # numpy raises ValueError if an axis is repeated.
        return (numpy.expand_dims(data, axis=tuple(self.axes)), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the shape produced by ``x.unsqueeze(axes=self.axes)``.
        """
        return (x.unsqueeze(axes=self.axes), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the input type unchanged.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends ``dict(temp=0)`` to its outputs.
        """
        # ``self.run`` is defined outside this class; its result is
        # concatenated to a tuple, so it is presumably a tuple too.
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

47 

48 

class Unsqueeze_11(Unsqueeze_1):
    """
    Variant of *Unsqueeze* selected for opsets 11 and 12,
    it inherits everything from ``Unsqueeze_1`` without any change.
    """

51 

52 

class Unsqueeze_13(OpRun):
    """
    Runtime for operator *Unsqueeze* when *axes* is given as a
    second input of `_run` instead of a node attribute.
    """

    # Expected attributes and their default values.
    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Forwards the arguments to ``OpRun.__init__``.

        :param onnx_node: forwarded to the base constructor
        :param desc: forwarded to the base constructor
        :param options: forwarded to the base constructor
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Unsqueeze_13.atts,
                       **options)
        # Axes are received at run time, nothing is stored here.
        self.axes = None

    def _run(self, data, axes=None):  # pylint: disable=W0221
        """
        Inserts dimensions of size 1 into *data*.

        :param data: array to expand
        :param axes: positions of the new dimensions in the output,
            an array (0-d or 1-d), a list, a tuple or a single integer
        :return: tuple holding the expanded array
        :raises RuntimeError: if *axes* is None
        """
        if axes is None:
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_13).")
        # ``atleast_1d`` gives scalars, 0-d arrays, lists and tuples the
        # same 1-d form, the previous code failed on plain lists and
        # tuples because it read attribute ``shape``.
        positions = tuple(numpy.atleast_1d(axes))
        return (numpy.expand_dims(data, axis=positions), )

    def _infer_shapes(self, x, axes=None):  # pylint: disable=W0221
        """
        Returns a shape object built with ``None`` as its shape
        and the dtype of *x*; *axes* is not used.
        """
        return (ShapeObject(None, dtype=x.dtype), )

    def _infer_types(self, x, axes=None):  # pylint: disable=W0221
        """
        Returns the type of the first input unchanged.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends ``dict(temp=0)`` to its outputs.
        """
        # ``self.run`` is defined outside this class; its result is
        # concatenated to a tuple, so it is presumably a tuple too.
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

83 

84 

# Alias *Unsqueeze* points to the implementation matching the opset
# supported by the installed onnx package: 13 and above, 11 or 12,
# anything older.
Unsqueeze = (
    Unsqueeze_13 if onnx_opset_version() >= 13
    else Unsqueeze_11 if onnx_opset_version() >= 11
    else Unsqueeze_1)