Coverage for mlprodict/npy/xop_variable.py: 94%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

173 statements  

1""" 

2@file 

3@brief Xop API to build ONNX graphs. Inspired by :epkg:`sklearn-onnx`.

4 

5.. versionadded:: 0.9 

6""" 

7import numpy 

8from onnx import ValueInfoProto 

9from onnx.helper import make_tensor_type_proto 

10from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 

11from onnx.defs import onnx_opset_version 

12from .. import __max_supported_opset__ 

13 

14 

def max_supported_opset():
    """
    Returns the highest opset usable for the main domain: the smaller of
    the opset this package supports and the opset the installed
    :epkg:`onnx` package implements.

    .. runpython::
        :showcode:

        from mlprodict.npy.xop_variable import max_supported_opset
        print("max_supported_opset() returns", max_supported_opset())
    """
    installed_opset = onnx_opset_version()
    return min(__max_supported_opset__, installed_opset)

26 

27 

def is_numpy_dtype(dtype):
    """
    Tells if a dtype is a numpy dtype.

    :param dtype: anything
    :return: boolean, True if *dtype* or ``numpy.dtype(dtype)`` is a key
        of ``NP_TYPE_TO_TENSOR_TYPE``, False otherwise (including when
        *dtype* is unhashable or not understood by :func:`numpy.dtype`)
    """
    if isinstance(dtype, (list, dict, Variable)):
        return False
    try:
        if dtype in NP_TYPE_TO_TENSOR_TYPE:
            return True
    except TypeError:
        # Unhashable objects (numpy arrays, sets, ...) cannot be
        # dictionary keys, hence cannot be a known dtype.
        return False
    try:
        # NOTE: numpy.dtype(None) is float64, so None is reported as True.
        dt = numpy.dtype(dtype)
    except TypeError:
        # numpy does not understand this object as a data type.
        return False
    return dt in NP_TYPE_TO_TENSOR_TYPE

43 

44 

def numpy_type_prototype(dtype):
    """
    Converts a numpy dtype into a TensorProto dtype.

    :param dtype: dtype (anything ``numpy.dtype`` accepts)
    :return: proto dtype (an integer from ``NP_TYPE_TO_TENSOR_TYPE``)
    :raises ValueError: if no ONNX element type matches *dtype*
    """
    if dtype in NP_TYPE_TO_TENSOR_TYPE:
        return NP_TYPE_TO_TENSOR_TYPE[dtype]
    dt = numpy.dtype(dtype)
    if dt in NP_TYPE_TO_TENSOR_TYPE:
        return NP_TYPE_TO_TENSOR_TYPE[dt]
    # dtype is wrapped in a 1-tuple: a tuple dtype (e.g. a subarray
    # specification) would otherwise be unpacked by the % operator.
    raise ValueError(  # pragma: no cover
        "Unable to convert dtype %r into ProtoType." % (dtype, ))

59 

60 

# numpy types returned unchanged by guess_numpy_type
# (numpy dtypes comparing equal to one of them are returned as well).
_NUMPY_TYPES = (
    numpy.float64, numpy.float32, numpy.float16,
    numpy.int8, numpy.uint8, numpy.int16, numpy.uint16,
    numpy.int32, numpy.uint32, numpy.int64, numpy.uint64,
    numpy.complex64, numpy.complex128,
    numpy.str_, numpy.bool_)

# Tensor types (as in :epkg:`sklearn-onnx`) matched on their class name,
# which avoids importing the library.
_TENSOR_TYPE_NAME_TO_NUMPY = {
    'FloatTensorType': numpy.float32,
    'DoubleTensorType': numpy.float64,
    'Float16TensorType': numpy.float16,
    'Int8TensorType': numpy.int8,
    'UInt8TensorType': numpy.uint8,
    'Int16TensorType': numpy.int16,
    'UInt16TensorType': numpy.uint16,
    'Int32TensorType': numpy.int32,
    'UInt32TensorType': numpy.uint32,
    'Int64TensorType': numpy.int64,
    'UInt64TensorType': numpy.uint64,
    'StringTensorType': numpy.str_,
    'BooleanTensorType': numpy.bool_,
    'Complex64TensorType': numpy.complex64,
    'Complex128TensorType': numpy.complex128,
}


def guess_numpy_type(data_type):
    """
    Guesses the corresponding numpy type based on data_type.

    :param data_type: a numpy scalar type (or an equal numpy dtype),
        the builtins `str` or `bool`, a tensor type instance matched on
        its class name (`FloatTensorType`, `Int64TensorType`, ...),
        or any object with an attribute `type` holding one of these
    :return: numpy type
    :raises NotImplementedError: if no numpy type can be inferred
    """
    if data_type in _NUMPY_TYPES:
        return data_type
    if data_type == str:
        return numpy.str_
    if data_type == bool:
        return numpy.bool_
    cl_name = data_type.__class__.__name__
    if cl_name in _TENSOR_TYPE_NAME_TO_NUMPY:
        return _TENSOR_TYPE_NAME_TO_NUMPY[cl_name]
    if hasattr(data_type, 'type'):
        # e.g. a variable object holding its tensor type in `.type`
        return guess_numpy_type(data_type.type)
    raise NotImplementedError(  # pragma: no cover
        "Unsupported data_type '{}'.".format(data_type))

89 

90 

class Variable:
    """
    An input or output to an ONNX graph.

    :param name: name, or another @see cl Variable to copy (all the
        other parameters must then be None)
    :param dtype: :epkg:`numpy` dtype (can be None)
    :param shape: shape, a tuple or a list (can be None)
    :param added_dtype: :epkg:`numpy` dtype specified at conversion time
        (can be None)
    :param added_shape: :epkg:`numpy` shape specified at conversion time,
        a tuple or a list (can be None)
    :raises TypeError: if a parameter has an unexpected type
    :raises ValueError: if *name* is a Variable and any other parameter
        is not None
    """

    def __init__(self, name, dtype=None, shape=None, added_dtype=None,
                 added_shape=None):
        # A dtype must describe a type, it cannot be a value (int, array)
        # or a composite object (tuple, Variable).
        if (dtype is not None and isinstance(
                dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                "Unexpected type %r for dtype." % type(dtype))
        if (added_dtype is not None and isinstance(
                added_dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                "Unexpected type %r for added_dtype." % type(added_dtype))
        if shape is not None and not isinstance(shape, (tuple, list)):
            raise TypeError(
                "Unexpected type %r for shape." % type(shape))
        if (added_shape is not None and not isinstance(
                added_shape, (tuple, list))):
            raise TypeError(
                "Unexpected type %r for added_shape." % type(added_shape))

        if isinstance(name, Variable):
            # Copy constructor: every attribute comes from *name*.
            if (dtype is not None or shape is not None or
                    added_dtype is not None or added_shape is not None):
                raise ValueError(  # pragma: no cover
                    "If name is a Variable, then all others attributes "
                    "should be None.")

            self.name_ = name.name_
            self.dtype_ = name.dtype_
            self.added_dtype_ = name.added_dtype_
            self.shape_ = name.shape_
            self.added_shape_ = name.added_shape_
        else:
            if not isinstance(name, str):
                raise TypeError(  # pragma: no cover
                    "name must be a string not %r." % type(name))

            self.name_ = name
            self.dtype_ = dtype
            self.added_dtype_ = added_dtype
            self.shape_ = shape
            self.added_shape_ = added_shape

    def to_skl2onnx(self, scope=None):
        """
        Converts this instance into an instance of *Variable*
        from :epkg:`sklearn-onnx`.

        :param scope: scope given to the :epkg:`sklearn-onnx` Variable
        :return: :epkg:`sklearn-onnx` Variable built from `self.dtype`
            and `self.shape`
        """
        from skl2onnx.common._topology import Variable as skl2onnxVariable  # delayed
        from skl2onnx.common.data_types import _guess_numpy_type  # delayed
        inst = _guess_numpy_type(self.dtype, self.shape)
        var = skl2onnxVariable(self.name, self.name, type=inst, scope=scope)
        return var

    @staticmethod
    def from_skl2onnx(var):
        """
        Converts var from :epkg:`sklearn-onnx` into this class.

        :param var: :epkg:`sklearn-onnx` Variable
        :return: @see cl Variable with the same name, dtype and shape
        """
        return Variable(var.onnx_name, guess_numpy_type(var.type),
                        shape=var.type.shape)

    @property
    def name(self):
        "Returns the variable name (`self.name_`)."
        return self.name_

    @property
    def dtype(self):
        "Returns `self.dtype_`."
        return self.dtype_

    @property
    def shape(self):
        "Returns `self.shape_`."
        return self.shape_

    @property
    def proto_type(self):
        "Returns the proto type for `self.dtype_` (0 if it is None)."
        if self.dtype_ is None:
            return 0
        return numpy_type_prototype(self.dtype_)

    @property
    def proto_added_type(self):
        """
        Returns the proto type for `self.added_dtype_` or, if it is None,
        for `self.dtype_` (0 if both are None).
        """
        # `or` must not be used: a plain numpy.dtype has no fields,
        # ``len(dtype) == 0``, so it evaluates to False.
        dt = (self.added_dtype_ if self.added_dtype_ is not None
              else self.dtype_)
        if dt is None:
            return 0
        return numpy_type_prototype(dt)

    @property
    def proto_added_shape(self):
        """
        Returns the shape for `self.added_shape_` or, if it is None,
        for `self.shape_`, as a list (None if both are None).
        """
        # `or` must not be used: the empty shape () of a scalar is falsy.
        dt = (self.added_shape_ if self.added_shape_ is not None
              else self.shape_)
        if dt is None:
            return None
        return list(dt)

    def __repr__(self):
        "usual"
        kwargs = dict(dtype=self.dtype_, shape=self.shape_,
                      added_dtype=self.added_dtype_,
                      added_shape=self.added_shape_)
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if len(kwargs) > 0:
            msg = ", " + ", ".join("%s=%r" % (k, v) for k, v in kwargs.items())
        else:
            msg = ''
        return "%s(%r%s)" % (
            self.__class__.__name__, self.name_, msg)

    def is_named(self, name):
        """
        Tells the variable is named like that.

        :raises TypeError: if *name* is not a string
        """
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                "name is expected to be a string not %r." % type(name))
        return self.name == name

    def copy_add(self, dtype):
        """
        Returns a copy of this variable with a new dtype.

        :param dtype: added type, or a numpy array whose dtype and shape
            become the added dtype and the added shape
        :return: @see cl Variable
        :raises RuntimeError: if an added dtype is already set
        """
        if self.added_dtype_ is not None:
            raise RuntimeError(  # pragma: no cover
                "Cannot copy as added_dtype is not None.")
        if isinstance(dtype, numpy.ndarray):
            dtype, shape = dtype.dtype, dtype.shape
        else:
            shape = None
        return Variable(self.name_, self.dtype_, self.shape_, dtype, shape)

    def copy_merge(self, var):
        """
        Merges information from both Variable.

        If *var* is not a @see cl Variable, this is `self.copy_add(var)`.
        Otherwise, returns a copy of this variable whose added dtype and
        added shape, when not already set, are taken from the dtype and
        shape of *var*.
        """
        if not isinstance(var, Variable):
            return self.copy_add(var)
        res = Variable(self.name_, self.dtype_,
                       self.shape_, self.added_dtype_,
                       self.added_shape_)
        if self.added_dtype_ is None and var.dtype_ is not None:
            res.added_dtype_ = var.dtype_
        if self.added_shape_ is None and var.shape_ is not None:
            res.added_shape_ = var.shape_
        return res

    def copy_name(self, name):
        """
        Returns a copy with a new name (the current one if *name* is empty
        or None).
        """
        return Variable(
            name or self.name_, self.dtype_,
            self.shape_, self.added_dtype_,
            self.added_shape_)

    def __eq__(self, other):
        """
        Compares name, shape and dtype; the added dtype and added shape
        are not taken into account.

        :raises TypeError: if *other* is not a @see cl Variable
        """
        if not isinstance(other, Variable):
            raise TypeError(
                "Unexpected type %r." % type(other))
        if self.name != other.name:
            return False
        if self.shape_ != other.shape_:
            return False
        if self.dtype_ != other.dtype_:
            return False
        return True

    def make_value_info(self):
        """
        Converts the variable into `onnx.ValueInfoProto`.

        :return: instance of `onnx.ValueInfoProto`
        """
        value_info = ValueInfoProto()
        value_info.name = self.name
        tensor_type_proto = make_tensor_type_proto(self.proto_type, self.shape)
        value_info.type.CopyFrom(tensor_type_proto)  # pylint: disable=E1101
        return value_info

    @staticmethod
    def from_pb(obj):
        """
        Creates a Variable from a protobuf object.

        :param obj: initializer, tensor
        :return: @see cl Variable
        """
        from ..onnx_tools.onnx2py_helper import from_pb
        name, ty, shape = from_pb(obj)
        return Variable(name, ty, shape=shape)

300 

301 

class NodeResultName:
    """
    Identifies one result of a node by its position among the node outputs.

    :param node: node producing the result
    :param index: position of the output
    """

    def __init__(self, node, index):
        self.node, self.index = node, index

    def __repr__(self):
        "Usual"
        cls_name = type(self).__name__
        return "%s(%r, %r)" % (cls_name, self.node, self.index)

    def get_name(self):
        """
        Returns the name stored at position `index` in `node.output_names`,
        or, when the node declares no output names, a suggestion built
        from the first three lowercase letters of `node.op_type`
        (e.g. ``out_add_0``).

        :raises RuntimeError: if `node` is None
        """
        owner = self.node
        if owner is None:
            raise RuntimeError(  # pragma: no cover
                "node must not be None.")
        declared = owner.output_names
        if declared is None:
            prefix = owner.op_type.lower()[:3]
            return "out_%s_%d" % (prefix, self.index)
        return declared[self.index].name

329 

330 

class DetectedVariable:
    """
    Associates a @see cl Variable with the node where it was found,
    used to detect the inputs and outputs of a graph.

    :param node: node where the variable was detected (can be None)
    :param var: instance of @see cl Variable
    :param index: index, only used if it is an output (a negative value
        is omitted from the representation)
    :raises TypeError: if *var* is not a @see cl Variable
    """

    def __init__(self, node, var, index):
        if isinstance(var, Variable):
            self.node = node
            self.var = var
            self.index = index
            return
        raise TypeError(  # pragma: no cover
            "Unexpected type %r, it should be a Variable." % type(var))

    @property
    def name(self):
        "Returns variable name."
        return self.var.name

    def __repr__(self):
        "usual"
        cls_name = self.__class__.__name__
        suffix = "" if self.index < 0 else ", %s" % self.index
        if self.node is not None:
            # nodes are identified by their class name and their id
            node_desc = "%s-%d" % (self.node.__class__.__name__, id(self.node))
            return "%s(%s, %r%s)" % (cls_name, node_desc, self.var, suffix)
        return "%s(None, %r%s)" % (cls_name, self.var, suffix)

364 

365 

class InputDetectedVariable(DetectedVariable):
    """
    @see cl DetectedVariable for graph inputs only: the index is
    always set to -1.
    """

    def __init__(self, node, var):
        super().__init__(node, var, -1)

374 

375 

class OutputDetectedVariable(DetectedVariable):
    """
    @see cl DetectedVariable for graph outputs only. It keeps the parent
    constructor, so the output *index* must be given.
    """