Coverage for mlprodict/onnx_tools/optim/graph_schema_helper.py: 78%


149 statements  

1""" 

2@file 

3@brief Functions to help guessing the final graph structure. 

4""" 

5import numpy 

6from onnx import TensorProto 

7 

8 

9def _guess_type(var): 

10 from skl2onnx.algebra.type_helper import _guess_type as skl2onnx__guess_type # delayed 

11 if isinstance(var, dict) and 'value' in var: 

12 return skl2onnx__guess_type(var['value']) # pragma: no cover 

13 return skl2onnx__guess_type(var) 

14 

15 
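A minimal usage sketch (not part of the measured file), assuming `skl2onnx` and `numpy` are installed; the array and the `{'value': ...}` wrapper are hypothetical inputs mirroring the two branches above:

# Hypothetical sketch: _guess_type maps a numpy array (or a constant wrapped
# in a dict under the 'value' key) to a skl2onnx type instance.
import numpy
arr = numpy.array([[1.0, 2.0]], dtype=numpy.float32)
print(_guess_type(arr))             # expected: a FloatTensorType
print(_guess_type({'value': arr}))  # same type, unwrapped from the dict
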

def get_defined_inputs(input_names, variables=None, dtype=None,
                       schema=None):
    """
    Retrieves defined inputs in already declared variables
    based on their names.

    @param      input_names     input names
    @param      variables       registered variables created
                                by previous operators
    @param      dtype           float computational type
    @param      schema          defined inputs by schema (*expected_inputs*)
    @return                     typed inputs as ``tuple(name, type)``
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType, FloatTensorType, DoubleTensorType)

    def guess_type_variable(name, schema):
        if variables is None:
            if (schema is None or
                    not isinstance(schema, (DataType, tuple))):
                return (  # pragma: no cover
                    DoubleTensorType() if dtype == numpy.float64 else FloatTensorType())
            return schema if isinstance(schema, DataType) else schema[1]
        if name in variables:
            ty = variables[name]
            if isinstance(ty, DataType):
                shape = ty.shape
                if 0 in shape:
                    raise RuntimeError(  # pragma: no cover
                        "Shape cannot be empty: name='{}', var={}".format(
                            name, ty))
                return variables[name]
            if isinstance(ty, dict) and 'value' in ty:
                # constant
                arr = ty['value']
                try:
                    return _guess_type(arr)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to guess type of variable '{}' - {}."
                        "".format(name, arr)) from e
            raise NotImplementedError(  # pragma: no cover
                "Unable to guess type for '{}' from '{}'.".format(
                    name, variables[name]))
        if isinstance(schema, (DataType, tuple)):
            sch = schema if isinstance(schema, DataType) else schema[1]
            if not isinstance(sch, str):
                return sch
        # Inputs. Let's assume it is a vector of floats.
        return DoubleTensorType() if dtype == numpy.float64 else FloatTensorType()

    if schema is None or len(schema) < len(input_names):
        inputs = [(name, guess_type_variable(name, None))
                  for name in input_names]
    else:
        inputs = [(name, guess_type_variable(name, schema=sch))
                  for name, sch in zip(input_names, schema)]
    return inputs

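A small usage sketch, assuming `skl2onnx` is available; the variable names and the registered type are hypothetical:

# Hypothetical sketch: resolve input types from previously registered
# variables; unregistered names fall back to a float tensor type.
from skl2onnx.common.data_types import FloatTensorType
known = {'X': FloatTensorType([None, 4])}
typed = get_defined_inputs(['X', 'extra'], variables=known,
                           dtype=numpy.float32)
# 'X' keeps its registered type, 'extra' defaults to FloatTensorType().
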

def get_defined_outputs(outputs, onnx_node, typed_inputs=None, variables=None,
                        dtype=None, schema=None, schema_inputs=None):
    """
    Gets types of predefined outputs when they cannot be inferred.
    Some part of it should be automated based
    on type constraints.

    :param outputs: requested outputs
    :param onnx_node: :epkg:`ONNX` node definition
    :param typed_inputs: known typed inputs of the node as ``tuple(name, type)``
    :param variables: registered variables created by previous operators
    :param dtype: float computational type
    :param schema: defined outputs by schema (*expected_outputs*)
    :param schema_inputs: defined inputs by schema (*expected_inputs*)
    :return: typed outputs as ``tuple(name, type)``
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType,
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        DoubleTensorType, _guess_type_proto, _guess_type_proto_str)

    if schema is None:
        ft = DoubleTensorType if dtype == numpy.float64 else FloatTensorType
    elif len(schema) != 1:
        raise ValueError(  # pragma: no cover
            "schema should only contain one output not {}.".format(schema))
    else:
        if isinstance(schema, DataType):
            ft = schema[0].__class__
        else:
            ft = schema[0][1].__class__

    if onnx_node.op_type in {'ZipMap', 'ArgMin', 'ArgMax', 'Shape',
                             'Greater', 'Less', 'Equal', 'TopK',
                             'Cast', 'ArrayFeatureExtractor',
                             'Reshape', 'Transpose', 'Scan',
                             'ConstantOfShape'}:
        if onnx_node.op_type == "ZipMap":
            # ZipMap
            otype = SequenceType(DictionaryType(
                Int64Type(), ft()))
            outputs = [(name, otype) for name in outputs]
        elif (onnx_node.op_type in ("ArgMin", "ArgMax", 'Shape') and
                len(outputs) == 1):
            # ArgMin, ArgMax, Shape
            outputs = [(outputs[0], Int64TensorType())]
        elif (onnx_node.op_type in ("Greater", "Less", 'Equal') and
                len(outputs) == 1):
            # Greater, Less, Equal
            outputs = [(outputs[0], BooleanTensorType())]
        elif onnx_node.op_type == "TopK" and len(outputs) == 2:
            # TopK
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    "Wrong typed_inputs, got {}.".format(typed_inputs))
            outputs = [(outputs[0], typed_inputs[0][1]),
                       (outputs[1], Int64TensorType())]
        elif onnx_node.op_type == "Cast" and len(outputs) == 1:
            # Cast
            ttyp = _guess_type_proto(onnx_node.attribute[0].i, dims=None)
            outputs = [(outputs[0], ttyp)]
        elif onnx_node.op_type == "ArrayFeatureExtractor":
            # ArrayFeatureExtractor
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    "Wrong typed_inputs, got {}.".format(typed_inputs))
            outputs = [(outputs[0], typed_inputs[0][1])]
        elif onnx_node.op_type in ('Reshape', 'Transpose'):
            # Reshape, Transpose
            outputs = [(outputs[0], typed_inputs[0][1].__class__())]
        elif onnx_node.op_type == 'Scan':
            # Scan
            if len(outputs) != len(typed_inputs):
                raise RuntimeError(  # pragma: no cover
                    "Dimension mismatch, operator Scan should have "
                    "the same number of inputs and outputs {} != {}"
                    ".".format(len(outputs), len(typed_inputs)))
            outputs = [(o, t[1].__class__())
                       for o, t in zip(outputs, typed_inputs)]
        elif onnx_node.op_type == "ConstantOfShape":
            # ConstantOfShape
            outputs = [(outputs[0], ft())]
    elif 'Classifier' in onnx_node.op_type:
        # Good chance that's a classifier.
        outputs = [(outputs[0], Int64TensorType()),
                   (outputs[1], ft())]
    else:
        if schema_inputs is not None and schema is not None:
            dt = {}
            for got, exp in zip(typed_inputs, schema_inputs):
                if isinstance(exp[1], str):
                    dt[exp[1]] = got
            out = []
            for i in range(len(outputs)):  # pylint: disable=C0200
                o = outputs[i]
                if isinstance(o, str):
                    exp = schema[i]
                    if exp[1] in dt:
                        out.append((o, dt[exp[1]][1].__class__()))
                    else:
                        nt = _guess_type_proto_str(exp[1], None)
                        out.append((o, nt))
                elif (isinstance(o, tuple) and
                        (isinstance(o[1], str) or o[1] is None)):
                    exp = schema[i]
                    if exp[1] in dt:
                        out.append((o[0], dt[exp[1]][1].__class__()))
                    else:
                        nt = _guess_type_proto_str(exp[1], None)
                        out.append((o[0], nt))
                else:
                    out.append(o)
            outputs = out
        elif len(typed_inputs) == 1 and len(outputs) == 1:
            # Default case
            # Assuming the only output is the same as the only input.
            outputs = [(outputs[0], typed_inputs[0][1])]
        else:
            # Default
            outputs = [(name, ft()) for name in outputs]

    for name, typ in outputs:
        if typ in ('T', None, '', 'I'):
            raise NotImplementedError(  # pragma: no cover
                "Undefined output type: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    typ, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
        if not isinstance(name, str):
            raise NotImplementedError(  # pragma: no cover
                "Undefined output name: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    name, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
    return outputs

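A sketch of one of the special-cased branches above, built with a hand-made node; the node and tensor names are hypothetical:

# Hypothetical sketch: a comparison operator such as 'Greater' always
# produces a boolean tensor, whatever the input types are.
from onnx import helper
from skl2onnx.common.data_types import FloatTensorType
node = helper.make_node('Greater', ['A', 'B'], ['C'])
typed_in = [('A', FloatTensorType()), ('B', FloatTensorType())]
outs = get_defined_outputs(['C'], node, typed_inputs=typed_in)
# expected: [('C', BooleanTensorType())]
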

def proto2vars(values):
    """
    Converts proto values to Variables.
    """
    from skl2onnx.common.data_types import (  # delayed
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        Int32TensorType, DoubleTensorType, FloatType,
        StringTensorType, Float16TensorType)

    def ptype2vttype(it, shape):
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatTensorType(shape)
        if it == TensorProto.DOUBLE:  # pylint: disable=E1101
            return DoubleTensorType(shape)
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64TensorType(shape)
        if it == TensorProto.INT32:  # pylint: disable=E1101
            return Int32TensorType(shape)
        if it == TensorProto.BOOL:  # pylint: disable=E1101
            return BooleanTensorType(shape)
        if it == TensorProto.STRING:  # pylint: disable=E1101
            return StringTensorType(shape)
        if Float16TensorType is not None:
            # Float16TensorType may be unavailable in older skl2onnx versions.
            if it == TensorProto.FLOAT16:  # pylint: disable=E1101
                return Float16TensorType(shape)
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {} with shape {}".format(it, shape))

    def ptype2vtype(it):
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatType()
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64Type()
        raise NotImplementedError(  # pragma: no cover
            "Unrecognized proto type {}".format(it))

    res = []
    for v_ in values:
        v = v_
        name = v.name if hasattr(v, 'name') else None
        if hasattr(v, 'type') and str(v.type) != '':
            t = v.type
            v = proto2vars([t])[0][1]
        elif hasattr(v, 'sequence_type') and str(v.sequence_type) != '':
            subtype = proto2vars([v.sequence_type.elem_type])[0][1]
            v = SequenceType(subtype)
        elif hasattr(v, 'tensor_type') and str(v.tensor_type) != '':
            tt = v.tensor_type
            el = tt.elem_type
            shape = tt.shape
            dim = shape.dim
            if len(dim) == 0:
                shape = []
            else:
                shape = [dim[i].dim_value for i in range(len(dim))]
            v = ptype2vttype(el, shape)
        elif hasattr(v, 'map_type') and str(v.map_type) != '':
            mt = v.map_type
            keyt = ptype2vtype(mt.key_type)
            valt = proto2vars([mt.value_type])[0][1]
            v = DictionaryType(keyt, valt)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unable to build a variable from {}.".format(v))
        if v.shape is not None and 0 in v.shape:
            # Replaces 0 by None
            new_shape = tuple(None if d == 0 else d for d in v.shape)
            if new_shape in ((None, ), None):
                v = v.__class__()
            else:
                v = v.__class__(new_shape)
        if v.shape is not None and 0 in v.shape:
            raise RuntimeError(  # pragma: no cover
                "Shape cannot be empty: '{}': {}.".format(
                    name, v_))
        res.append((name, v))
    return res
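A closing sketch, assuming `onnx` and `skl2onnx` are installed; the ValueInfoProto and its name are hypothetical:

# Hypothetical sketch: convert an ONNX ValueInfoProto into a named
# skl2onnx variable; a 0 dimension would be rewritten to None above.
from onnx import helper
value_info = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 4])
print(proto2vars([value_info]))  # expected: [('X', FloatTensorType(shape=[1, 4]))]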