Coverage for mlprodict/onnxrt/ops_cpu/op_batch_normalization.py: 84%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
12def _batchnorm_test_mode(x, s, bias, mean, var, epsilon=1e-5):
13 dims_x = len(x.shape)
14 dim_ones = (1,) * (dims_x - 2)
15 s = s.reshape(-1, *dim_ones)
16 bias = bias.reshape(-1, *dim_ones)
17 mean = mean.reshape(-1, *dim_ones)
18 var = var.reshape(-1, *dim_ones)
19 y = s * (x - mean) / numpy.sqrt(var + epsilon) + bias
20 return y.astype(x.dtype)
def _batchnorm_training_mode(x, s, bias, mean, var, momentum=0.9,
                             epsilon=1e-5):
    """
    Applies batch normalization using statistics computed on the batch
    and updates the running statistics.

    Mean and variance are computed per channel by reducing over every
    axis of *x* except axis 1. ``numpy.var`` is used with its default
    ``ddof=0``, so the variance is the population (biased) variance.
    *x* is expected to have at least two dimensions; ``numpy.delete``
    raises otherwise.

    :param x: input tensor, channels on axis 1
    :param s: scale, one value per channel
    :param bias: bias, one value per channel
    :param mean: running mean to update, one value per channel
    :param var: running variance to update, one value per channel
    :param momentum: weight kept from the running statistics
    :param epsilon: added to the batch variance before the square root
    :return: tuple ``(y, batch_mean, batch_var, new_running_mean,
        new_running_var)``, every element cast to the dtype of *x*
    """
    def _blend(running, current):
        # running * momentum + current * (1 - momentum)
        return running * momentum + current * (1 - momentum)

    all_axes = numpy.arange(len(x.shape))
    reduce_axes = tuple(numpy.delete(all_axes, 1))

    batch_mean = x.mean(axis=reduce_axes)
    batch_var = x.var(axis=reduce_axes)
    running_mean = _blend(mean, batch_mean)
    running_var = _blend(var, batch_var)

    # Normalization uses the batch statistics, not the running ones.
    normalized = _batchnorm_test_mode(x, s, bias, batch_mean, batch_var,
                                      epsilon=epsilon)
    outputs = (normalized, batch_mean, batch_var, running_mean, running_var)
    return tuple(t.astype(x.dtype) for t in outputs)
class BatchNormalization_9(OpRun):
    """
    BatchNormalization runtime that always evaluates in inference mode.

    The input is normalized with the provided *mean* and *var*, and a
    single output is produced. The module-level code selects this class
    as ``BatchNormalization`` only when ``onnx_opset_version() < 14``.
    """

    atts = {'epsilon': 1e-5, 'momentum': 0.9}

    def __init__(self, onnx_node, desc=None, **options):
        # Use this class's own attribute set. The module-level alias
        # ``BatchNormalization`` usually points at BatchNormalization_14,
        # whose attribute set also contains ``training_mode``.
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=BatchNormalization_9.atts,
                       **options)

    def _run(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """
        Normalizes *x* with the given per-channel statistics.

        ``self.epsilon`` is presumably populated by ``OpRun`` from the
        node attributes or the defaults in ``atts`` (not visible here).

        :return: one-element tuple holding the normalized tensor
        """
        res = _batchnorm_test_mode(
            x, scale, bias, mean, var, epsilon=self.epsilon)
        return (res, )

    def _infer_shapes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """The single output has the same shape descriptor as *x*."""
        return (x, )

    def _infer_types(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """The single output has the same type descriptor as *x*."""
        return (x, )

    def _infer_sizes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """
        Runs the operator and prepends a temporary-memory estimate.

        :return: ``(dict(temp=...), output)``, where ``temp`` is estimated
            as twice the number of bytes of *x*
        """
        res = self.run(x, scale, bias, mean, var)
        return (dict(temp=x.size * x.dtype.itemsize * 2), ) + res
class BatchNormalization_14(OpRun):
    """
    BatchNormalization runtime supporting the ``training_mode`` attribute.

    * ``training_mode == 0``: *x* is normalized with the provided *mean*
      and *var*, and a single output is produced.
    * otherwise: *x* is normalized with statistics computed on the batch,
      and the running mean and variance (updated with *momentum*) are
      returned as two extra outputs, for three outputs in total.
    """

    atts = {'epsilon': 1e-5, 'momentum': 0.9, 'training_mode': 0}

    def __init__(self, onnx_node, desc=None, **options):
        # Name this class explicitly instead of going through the
        # module-level alias ``BatchNormalization``, which is chosen at
        # import time and may refer to another implementation.
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=BatchNormalization_14.atts,
                       **options)

    def _run(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """
        Evaluates the operator.

        ``self.epsilon``, ``self.momentum`` and ``self.training_mode`` are
        presumably populated by ``OpRun`` from the node attributes or the
        defaults in ``atts`` (not visible here).

        :return: ``(y, )`` in inference mode,
            ``(y, running_mean, running_var)`` in training mode
        """
        if self.training_mode == 0:
            res = _batchnorm_test_mode(
                x, scale, bias, mean, var, epsilon=self.epsilon)
            return (res, )
        # The batch statistics themselves are not exposed as outputs.
        res, _saved_mean, _saved_var, output_mean, output_var = (
            _batchnorm_training_mode(x, scale, bias, mean, var,
                                     self.momentum, self.epsilon))
        return res, output_mean, output_var

    def _infer_shapes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """One shape descriptor per output produced by :meth:`_run`."""
        if self.training_mode == 0:
            return (x, )
        return (x, mean, var)

    def _infer_types(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """One type descriptor per output produced by :meth:`_run`."""
        if self.training_mode == 0:
            return (x, )
        # Three outputs in training mode: y, running_mean, running_var.
        return (x, mean, var)

    def _infer_sizes(self, x, scale, bias, mean, var):  # pylint: disable=W0221
        """
        Runs the operator and prepends a temporary-memory estimate.

        :return: ``(dict(temp=...), *outputs)``, where ``temp`` is
            estimated as 2x the bytes of *x* in inference mode and 4x in
            training mode
        """
        res = self.run(x, scale, bias, mean, var)
        factor = 2 if self.training_mode == 0 else 4
        return (dict(temp=x.size * x.dtype.itemsize * factor), ) + res
# Expose the implementation matching the opset supported by the installed
# onnx package under the generic operator name.
BatchNormalization = (
    BatchNormalization_14 if onnx_opset_version() >= 14
    else BatchNormalization_9)