Coverage for mlprodict/onnxrt/ops_cpu/op_conv_transpose.py: 76%
Shortcuts on this page:
- r, m, x: toggle line displays
- j, k: next / previous highlighted chunk
- 0 (zero): top of page
- 1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObjectFct
10from .op_conv_transpose_ import ( # pylint: disable=E0611,E0401
11 ConvTransposeFloat, ConvTransposeDouble)
class ConvTranspose(OpRun):
    # Runtime for the ConvTranspose operator. The actual computation is
    # delegated to ConvTransposeFloat / ConvTransposeDouble imported from
    # op_conv_transpose_ (presumably a compiled extension, given the pylint
    # E0611/E0401 suppression on that import); one kernel instance is kept
    # per float precision.

    # Attribute names expected on the node, with their default values;
    # handed to OpRun.__init__ as ``expected_attributes``.
    atts = {
        'auto_pad': 'NOTSET',
        'group': 1,
        'dilations': [],
        'kernel_shape': [],
        'pads': [],
        'strides': [],
        'output_padding': [],
        'output_shape': [],
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Forwards *onnx_node*, *desc* and *options* to ``OpRun.__init__``
        along with this operator's expected attributes, then builds the
        kernels through ``_init``.
        """
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ConvTranspose.atts, **options)
        self._init()

    def _init(self):
        """
        Creates the float32 kernel (``rt32_``) and the float64 kernel
        (``rt64_``) and configures both with the same attribute values.
        List-valued attributes are passed as int64 numpy arrays;
        ``auto_pad`` and ``group`` are passed through unchanged.
        """
        # NOTE(review): self.auto_pad, self.dilations, ... are presumably
        # populated by OpRun.__init__ from the keys of ``atts`` -- confirm.
        def to_int64(values):
            return numpy.array(values, dtype=numpy.int64)

        self.rt32_ = ConvTransposeFloat()
        self.rt64_ = ConvTransposeDouble()
        for kernel in (self.rt32_, self.rt64_):
            # Fresh int64 arrays are built for each kernel.
            kernel.init(
                self.auto_pad, to_int64(self.dilations), self.group,
                to_int64(self.kernel_shape), to_int64(self.pads),
                to_int64(self.strides), to_int64(self.output_padding),
                to_int64(self.output_shape))

    def _run(self, X, W, B=None):  # pylint: disable=W0221
        """
        Evaluates the operator on inputs *X*, *W* and the optional *B*.
        When ``X.dtype`` is float32 the float32 kernel is used; any other
        dtype goes to the float64 kernel. Returns a 1-tuple holding the
        kernel's output.
        """
        kernel = self.rt32_ if X.dtype == numpy.float32 else self.rt64_
        return (kernel.compute(X, W, B), )

    def _infer_shapes(self, X, W, B=None):  # pylint: disable=W0221
        """
        Returns a 1-tuple with a ``ShapeObjectFct`` named ``"ConvTranspose"``
        whose dtype is ``X.dtype``. Its callback (presumably invoked by
        ShapeObjectFct with the shapes of X, W and B, the last possibly
        None) runs the float32 kernel on all-ones arrays of those shapes
        and reports the result's shape, so each evaluation costs one full
        kernel call.
        """
        def compute_shape(xshape, wshape, bshape):
            x_ones = numpy.ones(xshape, dtype=numpy.float32)
            w_ones = numpy.ones(wshape, dtype=numpy.float32)
            if bshape is None:
                b_ones = None
            else:
                b_ones = numpy.ones(bshape, dtype=numpy.float32)
            # self.rt32_ is looked up when the callback runs, not when the
            # ShapeObjectFct is built.
            return self.rt32_.compute(x_ones, w_ones, b_ones).shape

        shape = ShapeObjectFct(
            compute_shape, X, W, B, name="ConvTranspose", dtype=X.dtype)
        return (shape, )

    def _infer_types(self, X, W, B=None):  # pylint: disable=W0221
        """
        The output type is taken from *X*: returns ``(X, )`` unchanged.
        """
        return (X, )