Coverage for mlprodict/onnxrt/ops_cpu/op_negative_log_likelihood_loss.py: 86%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

58 statements  

1""" 

2@file 

3@brief Runtime operator. 

4""" 

5import numpy 

6from ..shape_object import ShapeObject 

7from ._op import OpRun 

8 

9 

10def _compute_negative_log_likelihood_loss(x, target, weight=None, 

11 reduction=b'mean', ignore_index=None): 

12 """ 

13 Modified version of `softmaxcrossentropy.py 

14 <https://github.com/onnx/onnx/blob/main/onnx/backend/ 

15 test/case/node/negativeloglikelihoodloss.py>`_ to handle other type 

16 than float32. 

17 """ 

18 input_shape = x.shape 

19 if len(input_shape) == 1: 

20 raise RuntimeError("Unsupported shape %r." % (input_shape, )) 

21 

22 target_shape = target.shape 

23 N = input_shape[0] 

24 C = input_shape[1] 

25 

26 # initialize the positional weights when required 

27 gather_weight = None 

28 if weight is not None: 

29 # setting mode='clip' to deal with ignore_index > C or < 0 cases. 

30 # when the target value is > C or < 0, it doesn't matter which value we are 

31 # taking in gather_weight, since it will be set to 0 in the following if-block 

32 # use numpy.int32 to make it compatible with x86 machines 

33 gather_weight = numpy.take(weight, numpy.array( 

34 target, dtype=numpy.int32), mode='clip') 

35 # set `ignore_index`'s loss weight to 0. 

36 # The loss tensor will be multiplied by this weight tensor, 

37 # so `ingore_index`'s loss value will be eliminated. 

38 if ignore_index is not None: 

39 gather_weight = numpy.where( 

40 target == ignore_index, 0, gather_weight).astype(dtype=x.dtype) 

41 elif ignore_index != -1: 

42 gather_weight = numpy.where( 

43 target == ignore_index, 0, 1).astype(dtype=x.dtype) 

44 

45 # if input is 4-d and above, make it 3-d 

46 if len(input_shape) != 3: 

47 x = x.reshape((N, C, -1)) 

48 target = target.reshape((N, -1)) 

49 

50 # Get a dimension from the reshaped input. 

51 # If the original input shape is [N, C, H, W], 

52 # the D here should be H * W because we reshape 

53 # [N, C, H, W] to [N, C, H * W]. 

54 D = x.shape[2] 

55 neg_gather_element_input = numpy.zeros((N, D), dtype=x.dtype) 

56 for i in range(N): 

57 for d in range(D): 

58 if target[i][d] != ignore_index: 

59 neg_gather_element_input[i][d] = -x[i][target[i][d]][d] 

60 

61 loss = neg_gather_element_input 

62 

63 # if the input was 4-d or above reshape to the right shape 

64 if len(input_shape) != 3: 

65 loss = loss.reshape(target_shape) 

66 

67 # apply the weights when required 

68 if gather_weight is not None: 

69 loss = gather_weight * loss 

70 if reduction == b'mean': 

71 loss = loss.sum() / gather_weight.sum() 

72 return loss 

73 

74 if reduction == b'mean': 

75 loss = numpy.mean(loss) 

76 elif reduction == b'sum': 

77 loss = numpy.sum(loss) 

78 return (loss, ) 

79 

80 

class NegativeLogLikelihoodLoss(OpRun):
    """
    Python runtime for function *NegativeLogLikelihoodLoss*.

    Declared attributes and their defaults (see ``atts``):
    ``reduction`` (``b'mean'``) and ``ignore_index`` (``-1``).
    """

    atts = {'reduction': b'mean', 'ignore_index': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=NegativeLogLikelihoodLoss.atts,
            **options)

    def _run(self, x, target, weight=None):  # pylint: disable=W0221
        # presumably OpRun exposes the declared attributes on self -- confirm
        reduction = self.reduction  # pylint: disable=E1101
        ignore_index = self.ignore_index  # pylint: disable=E1101
        return _compute_negative_log_likelihood_loss(
            x, target, weight=weight, reduction=reduction,
            ignore_index=ignore_index)

    def _infer_shapes(self, x, target, weight=None):  # pylint: disable=W0221
        # One result for a single-output node, two otherwise; each result is
        # a distinct ShapeObject with an unknown shape and the input dtype.
        count = 1 if len(self.onnx_node.output) == 1 else 2
        return tuple(ShapeObject(None, dtype=x.dtype) for _ in range(count))

    def _infer_types(self, x, target, weight=None):  # pylint: disable=W0221
        # Every output shares the dtype of the input.
        count = 1 if len(self.onnx_node.output) == 1 else 2
        return (x.dtype, ) * count

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        # Runs the operator and prepends a zero temporary-memory estimate.
        return ({'temp': 0}, ) + self.run(*args)