Coverage for mlprodict/tools/zoo.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

99 statements  

1""" 

2@file 

3@brief Tools to test models from the :epkg:`ONNX Zoo`. 

4 

5.. versionadded:: 0.6 

6""" 

7import os 

8import urllib.request 

9from collections import OrderedDict 

10import numpy 

11from onnx import TensorProto, numpy_helper 

12try: 

13 from .ort_wrapper import InferenceSession 

14except ImportError: 

15 from mlprodict.tools.ort_wrapper import InferenceSession 

16 

17 

def short_list_zoo_models():
    """
    Returns a short list from :epkg:`ONNX Zoo`.

    :return: list of dictionaries.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.tools.zoo import short_list_zoo_models
        pprint.pprint(short_list_zoo_models())
    """
    # All models live under the same root of the ONNX Zoo repository.
    base = "https://github.com/onnx/models/raw/main/vision/"
    # (name, path below *base*, extraction folder or None when the
    # archive expands into a folder guessable from the file name)
    table = [
        ("mobilenet",
         "classification/mobilenet/model/mobilenetv2-7.tar.gz",
         None),
        ("resnet18",
         "classification/resnet/model/resnet18-v1-7.tar.gz",
         None),
        ("squeezenet",
         "classification/squeezenet/model/squeezenet1.0-9.tar.gz",
         "squeezenet"),
        ("densenet121",
         "classification/densenet-121/model/densenet-9.tar.gz",
         "densenet121"),
        ("inception2",
         "classification/inception_and_googlenet/inception_v2/"
         "model/inception-v2-9.tar.gz",
         None),
        ("shufflenet",
         "classification/shufflenet/model/shufflenet-9.tar.gz",
         None),
        ("efficientnet-lite4",
         "classification/efficientnet-lite4/model/"
         "efficientnet-lite4-11.tar.gz",
         None),
    ]
    res = []
    for model_name, suffix, folder in table:
        row = dict(name=model_name, model=base + suffix)
        if folder is not None:
            row['folder'] = folder
        res.append(row)
    return res

59 

60 

61def _download_url(url, output_path, name, verbose=False): 

62 if verbose: # pragma: no cover 

63 from tqdm import tqdm 

64 

65 class DownloadProgressBar(tqdm): 

66 "progress bar hook" 

67 

68 def update_to(self, b=1, bsize=1, tsize=None): 

69 "progress bar hook" 

70 if tsize is not None: 

71 self.total = tsize 

72 self.update(b * bsize - self.n) 

73 

74 with DownloadProgressBar(unit='B', unit_scale=True, 

75 miniters=1, desc=name) as t: 

76 urllib.request.urlretrieve( 

77 url, filename=output_path, reporthook=t.update_to) 

78 else: 

79 urllib.request.urlretrieve(url, filename=output_path) 

80 

81 

def load_data(folder):
    """
    Restores protobuf data stored in a folder.

    Every ``.pb`` file in *folder* is parsed as a :epkg:`TensorProto`;
    a file whose name starts with ``input`` becomes an input tensor,
    one starting with ``output`` becomes an expected output tensor.
    Files with any other extension are ignored.

    :param folder: folder to explore
    :return: dictionary with two keys, ``'in'`` and ``'out'``, each an
        :class:`OrderedDict` mapping file names (without extension)
        to numpy arrays
    :raises ValueError: if a ``.pb`` file name starts with neither
        ``input`` nor ``output``
    """
    res = OrderedDict()
    res['in'] = OrderedDict()
    res['out'] = OrderedDict()
    # Sorted so that input_0, input_1, ... are inserted in a stable
    # order: os.listdir order is platform-dependent and callers
    # (see verify_model) rely on the insertion order of these
    # dictionaries to map tensors onto the model inputs.
    files = sorted(os.listdir(folder))
    for name in files:
        noext, ext = os.path.splitext(name)
        if ext == '.pb':
            data = TensorProto()
            with open(os.path.join(folder, name), 'rb') as f:
                data.ParseFromString(f.read())
            if noext.startswith('input'):
                res['in'][noext] = numpy_helper.to_array(data)
            elif noext.startswith('output'):
                res['out'][noext] = numpy_helper.to_array(data)
            else:
                raise ValueError(  # pragma: no cover
                    "Unable to guess anything about %r." % noext)

    return res

108 

109 

def download_model_data(name, model=None, cache=None, verbose=False):
    """
    Downloads a model and returns a link to the local
    :epkg:`ONNX` file and data which can be used as inputs.

    :param name: model name (see @see fn short_list_zoo_models)
    :param model: url or empty to get the default value
        returned by @see fn short_list_zoo_models)
    :param cache: folder to cache the downloaded data
    :param verbose: display a progress bar
    :return: local onnx file, input data
    :raises ValueError: if *name* has no default url
    :raises ConnectionError: if the downloaded file is smaller than
        1 Mb, which is assumed to be a failed download
    :raises FileNotFoundError: if the extracted folder or the onnx
        file inside it cannot be located
    """
    suggested_folder = None
    if model is None:
        # looks the url up in the short list of known models
        model_list = short_list_zoo_models()
        for mod in model_list:
            if mod['name'] == name:
                model = mod['model']
                if 'folder' in mod:  # pylint: disable=R1715
                    suggested_folder = mod['folder']
                break
        if model is None:
            raise ValueError(
                "Unable to find a default value for name=%r." % name)

    # downloads (skipped when the archive is already cached)
    last_name = model.split('/')[-1]
    if cache is None:
        cache = os.path.abspath('.')  # pragma: no cover
    dest = os.path.join(cache, last_name)
    if not os.path.exists(dest):
        _download_url(model, dest, name, verbose=verbose)
    size = os.stat(dest).st_size
    if size < 2 ** 20:  # pragma: no cover
        # a file below 1 Mb is assumed to be an error page, not a model
        os.remove(dest)
        raise ConnectionError(
            "Unable to download model from %r." % model)

    # decompression: model.tar.gz -> model.tar -> folder
    outtar = os.path.splitext(dest)[0]
    if not os.path.exists(outtar):
        from pyquickhelper.filehelper.compression_helper import (
            ungzip_files)
        ungzip_files(dest, unzip=False, where_to=cache, remove_space=False)

    onnx_file = os.path.splitext(outtar)[0]
    if not os.path.exists(onnx_file):
        from pyquickhelper.filehelper.compression_helper import (
            untar_files)
        untar_files(outtar, where_to=cache)

    # guesses the folder the archive expanded into
    if suggested_folder is not None:
        # The archive is extracted below *cache*, so look there first;
        # the bare relative name is kept as a fallback for the previous
        # behaviour which only worked when cwd == cache.
        fold_onnx = [os.path.join(cache, suggested_folder),
                     suggested_folder]
    else:
        fold_onnx = [onnx_file, onnx_file.split('-')[0],
                     '-'.join(onnx_file.split('-')[:-1]),
                     '-'.join(onnx_file.split('-')[:-1]).replace('-', '_')]
    # deduplicate by absolute path so that two candidate spellings of
    # the same folder are not counted twice
    fold_onnx_ok = []
    for cand in fold_onnx:
        full = os.path.abspath(cand)
        if os.path.exists(full) and full not in fold_onnx_ok:
            fold_onnx_ok.append(full)
    if len(fold_onnx_ok) != 1:
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find an existing folder among %r." % fold_onnx)
    onnx_file = fold_onnx_ok[0]

    # there should be exactly one .onnx file in that folder
    onnx_files = [_ for _ in os.listdir(onnx_file) if _.endswith(".onnx")]
    if len(onnx_files) != 1:
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find any onnx file in %r." % onnx_files)
    final_onnx = os.path.join(onnx_file, onnx_files[0])

    # data: every subfolder holds one example (input and output tensors)
    data = [_ for _ in os.listdir(onnx_file)
            if os.path.isdir(os.path.join(onnx_file, _))]
    examples = OrderedDict()
    for f in data:
        examples[f] = load_data(os.path.join(onnx_file, f))

    return final_onnx, examples

186 

187 

def verify_model(onnx_file, examples, runtime=None, abs_tol=5e-4,
                 verbose=0, fLOG=None):
    """
    Verifies a model.

    Runs every example through the model and compares the produced
    outputs to the stored expected outputs within *abs_tol*.

    :param onnx_file: ONNX file
    :param examples: list of examples to verify
    :param runtime: a runtime to use
    :param abs_tol: error tolerance when checking the output
    :param verbose: verbosity level for for runtime other than
        `'onnxruntime'`
    :param fLOG: logging function when `verbose > 0`
    :return: errors for every sample
    :raises RuntimeError: when the number of inputs or outputs does
        not match the model
    :raises ValueError: when a shape or a value differs beyond *abs_tol*
    """
    if runtime in ('onnxruntime', 'onnxruntime-cuda'):
        sess = InferenceSession(onnx_file, runtime=runtime)
        meth = lambda data, s=sess: s.run(None, data)
        names = [p.name for p in sess.get_inputs()]
        # onnxruntime returns outputs as a positional list, so the
        # "output names" are just the positions
        onames = list(range(len(sess.get_outputs())))
    else:
        def _lin_(sess, data, names):
            r = sess.run(data, verbose=verbose, fLOG=fLOG)
            return [r[n] for n in names]

        from ..onnxrt import OnnxInference
        sess = OnnxInference(
            onnx_file, runtime=runtime,
            runtime_options=dict(log_severity_level=3))
        names = sess.input_names
        onames = sess.output_names
        meth = lambda data, s=sess, ns=onames: _lin_(s, data, ns)

    rows = []
    for index, (name, data_inout) in enumerate(examples.items()):
        data = data_inout["in"]
        if len(data) != len(names):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of inputs %d != %d\ninputs: %r\nmodel: %r."
                "" % (len(data), len(names), list(sorted(data)), names))
        # maps stored tensors onto model input names by position
        # (relies on the insertion order of *data*)
        inputs = {n: data[v] for n, v in zip(names, data)}
        outputs = meth(inputs)
        expected = data_inout['out']
        if len(outputs) != len(onames):
            raise RuntimeError(  # pragma: no cover
                "Number of outputs %d is != expected outputs %d." % (
                    len(outputs), len(onames)))
        for i, (output, expect) in enumerate(zip(outputs, expected.items())):
            if output.shape != expect[1].shape:
                raise ValueError(  # pragma: no cover
                    "Shape mismatch got %r != expected %r." % (
                        output.shape, expect[1].shape))
            diff = numpy.abs(output - expect[1]).ravel()
            absolute = diff.max()
            # NOTE(review): when the median of diff is 0 but the max is
            # not, this division yields inf; the value is only used for
            # reporting so it is left as is.
            relative = absolute / numpy.median(diff) if absolute > 0 else 0.
            if absolute > abs_tol:
                raise ValueError(  # pragma: no cover
                    "Example %d, inferred and expected results are different "
                    "for output %d: abs=%r rel=%r (runtime=%r)."
                    "" % (index, i, absolute, relative, runtime))
            rows.append(dict(name=name, i=i, abs=absolute, rel=relative))
    return rows