Coverage for mlprodict/tools/zoo.py: 98%
1"""
2@file
3@brief Tools to test models from the :epkg:`ONNX Zoo`.
5.. versionadded:: 0.6
6"""
7import os
8import urllib.request
9from collections import OrderedDict
10import numpy
11from onnx import TensorProto, numpy_helper
12try:
13 from .ort_wrapper import InferenceSession
14except ImportError:
15 from mlprodict.tools.ort_wrapper import InferenceSession


def short_list_zoo_models():
    """
    Returns a short list from :epkg:`ONNX Zoo`.

    :return: list of dictionaries.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.tools.zoo import short_list_zoo_models
        pprint.pprint(short_list_zoo_models())
    """
    return [
        dict(name="mobilenet",
             model="https://github.com/onnx/models/raw/main/vision/"
                   "classification/mobilenet/model/mobilenetv2-7.tar.gz"),
        dict(name="resnet18",
             model="https://github.com/onnx/models/raw/main/vision/"
                   "classification/resnet/model/resnet18-v1-7.tar.gz"),
        dict(name="squeezenet",
             model="https://github.com/onnx/models/raw/main/vision/"
                   "classification/squeezenet/model/squeezenet1.0-9.tar.gz",
             folder="squeezenet"),
        dict(name="densenet121",
             model="https://github.com/onnx/models/raw/main/vision/"
                   "classification/densenet-121/model/densenet-9.tar.gz",
             folder="densenet121"),
        dict(name="inception2",
             model="https://github.com/onnx/models/raw/main/vision/"
                   "classification/inception_and_googlenet/inception_v2/"
                   "model/inception-v2-9.tar.gz"),
        dict(name="shufflenet",
             model="https://github.com/onnx/models/raw/main/vision/"
                   "classification/shufflenet/model/shufflenet-9.tar.gz"),
        dict(name="efficientnet-lite4",
             model="https://github.com/onnx/models/raw/main/vision/"
                   "classification/efficientnet-lite4/model/"
                   "efficientnet-lite4-11.tar.gz"),
    ]


def _download_url(url, output_path, name, verbose=False):
    "Downloads *url* into *output_path*, with a progress bar if *verbose* is True."
    if verbose:  # pragma: no cover
        from tqdm import tqdm

        class DownloadProgressBar(tqdm):
            "progress bar hook"

            def update_to(self, b=1, bsize=1, tsize=None):
                "progress bar hook"
                if tsize is not None:
                    self.total = tsize
                self.update(b * bsize - self.n)

        with DownloadProgressBar(unit='B', unit_scale=True,
                                 miniters=1, desc=name) as t:
            urllib.request.urlretrieve(
                url, filename=output_path, reporthook=t.update_to)
    else:
        urllib.request.urlretrieve(url, filename=output_path)


def load_data(folder):
    """
    Restores protobuf data stored in a folder.

    :param folder: folder containing the protobuf files (`*.pb`)
    :return: dictionary with two keys, `'in'` for the inputs and
        `'out'` for the expected outputs
    """
    res = OrderedDict()
    res['in'] = OrderedDict()
    res['out'] = OrderedDict()
    files = os.listdir(folder)
    for name in files:
        noext, ext = os.path.splitext(name)
        if ext == '.pb':
            data = TensorProto()
            with open(os.path.join(folder, name), 'rb') as f:
                data.ParseFromString(f.read())
            if noext.startswith('input'):
                res['in'][noext] = numpy_helper.to_array(data)
            elif noext.startswith('output'):
                res['out'][noext] = numpy_helper.to_array(data)
            else:
                raise ValueError(  # pragma: no cover
                    "Unable to guess anything about %r." % noext)

    return res
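

# Hedged usage sketch (illustration only, not part of the original module):
# load_data expects a folder laid out like the test data sets shipped in the
# ONNX Zoo archives, i.e. files ``input_*.pb`` and ``output_*.pb`` side by
# side. The folder name below is an assumed example of that layout.
#
#   data = load_data("mobilenetv2-7/test_data_set_0")
#   print({k: v.shape for k, v in data["in"].items()})
#   print({k: v.shape for k, v in data["out"].items()})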


def download_model_data(name, model=None, cache=None, verbose=False):
    """
    Downloads a model and returns a link to the local
    :epkg:`ONNX` file and data which can be used as inputs.

    :param name: model name (see @see fn short_list_zoo_models)
    :param model: url of the model or None to take the default value
        returned by @see fn short_list_zoo_models
    :param cache: folder used to cache the downloaded data
    :param verbose: display a progress bar
    :return: local onnx file, test data (inputs and expected outputs)
    """
    suggested_folder = None
    if model is None:
        model_list = short_list_zoo_models()
        for mod in model_list:
            if mod['name'] == name:
                model = mod['model']
                if 'folder' in mod:  # pylint: disable=R1715
                    suggested_folder = mod['folder']
                break
        if model is None:
            raise ValueError(
                "Unable to find a default value for name=%r." % name)

    # downloads
    last_name = model.split('/')[-1]
    if cache is None:
        cache = os.path.abspath('.')  # pragma: no cover
    dest = os.path.join(cache, last_name)
    if not os.path.exists(dest):
        _download_url(model, dest, name, verbose=verbose)
    size = os.stat(dest).st_size
    if size < 2 ** 20:  # pragma: no cover
        os.remove(dest)
        raise ConnectionError(
            "Unable to download model from %r." % model)

    outtar = os.path.splitext(dest)[0]
    if not os.path.exists(outtar):
        from pyquickhelper.filehelper.compression_helper import (
            ungzip_files)
        ungzip_files(dest, unzip=False, where_to=cache, remove_space=False)

    onnx_file = os.path.splitext(outtar)[0]
    if not os.path.exists(onnx_file):
        from pyquickhelper.filehelper.compression_helper import (
            untar_files)
        untar_files(outtar, where_to=cache)

    if suggested_folder is not None:
        fold_onnx = [suggested_folder]
    else:
        fold_onnx = [onnx_file, onnx_file.split('-')[0],
                     '-'.join(onnx_file.split('-')[:-1]),
                     '-'.join(onnx_file.split('-')[:-1]).replace('-', '_')]
    fold_onnx_ok = [_ for _ in fold_onnx if os.path.exists(_)]
    if len(fold_onnx_ok) != 1:
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find an existing folder among %r." % fold_onnx)
    onnx_file = fold_onnx_ok[0]

    onnx_files = [_ for _ in os.listdir(onnx_file) if _.endswith(".onnx")]
    if len(onnx_files) != 1:
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find any onnx file in %r." % onnx_files)
    final_onnx = os.path.join(onnx_file, onnx_files[0])

    # data
    data = [_ for _ in os.listdir(onnx_file)
            if os.path.isdir(os.path.join(onnx_file, _))]
    examples = OrderedDict()
    for f in data:
        examples[f] = load_data(os.path.join(onnx_file, f))

    return final_onnx, examples
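

# Hedged usage sketch (illustration only, not part of the original module):
# download the "mobilenet" entry of short_list_zoo_models() into a local
# cache folder and inspect the bundled test data sets. The model name and
# the cache path "cache" are assumptions chosen for the example.
#
#   onnx_file, examples = download_model_data("mobilenet", cache="cache")
#   print(onnx_file)                     # path of the extracted .onnx file
#   for key, data in examples.items():   # e.g. "test_data_set_0"
#       print(key, list(data["in"]), list(data["out"]))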


def verify_model(onnx_file, examples, runtime=None, abs_tol=5e-4,
                 verbose=0, fLOG=None):
    """
    Verifies a model.

    :param onnx_file: ONNX file
    :param examples: dictionary of examples to verify
        (as returned by @see fn download_model_data)
    :param runtime: a runtime to use
    :param abs_tol: error tolerance when checking the output
    :param verbose: verbosity level for runtimes other than
        `'onnxruntime'`
    :param fLOG: logging function when `verbose > 0`
    :return: errors for every sample
    """
    if runtime in ('onnxruntime', 'onnxruntime-cuda'):
        sess = InferenceSession(onnx_file, runtime=runtime)
        meth = lambda data, s=sess: s.run(None, data)
        names = [p.name for p in sess.get_inputs()]
        onames = list(range(len(sess.get_outputs())))
    else:
        def _lin_(sess, data, names):
            r = sess.run(data, verbose=verbose, fLOG=fLOG)
            return [r[n] for n in names]

        from ..onnxrt import OnnxInference
        sess = OnnxInference(
            onnx_file, runtime=runtime,
            runtime_options=dict(log_severity_level=3))
        names = sess.input_names
        onames = sess.output_names
        meth = lambda data, s=sess, ns=onames: _lin_(s, data, ns)

    rows = []
    for index, (name, data_inout) in enumerate(examples.items()):
        data = data_inout["in"]
        if len(data) != len(names):
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of inputs %d != %d\ninputs: %r\nmodel: %r."
                "" % (len(data), len(names), list(sorted(data)), names))
        inputs = {n: data[v] for n, v in zip(names, data)}
        outputs = meth(inputs)
        expected = data_inout['out']
        if len(outputs) != len(onames):
            raise RuntimeError(  # pragma: no cover
                "Number of outputs %d != expected number of outputs %d." % (
                    len(outputs), len(onames)))
        for i, (output, expect) in enumerate(zip(outputs, expected.items())):
            if output.shape != expect[1].shape:
                raise ValueError(  # pragma: no cover
                    "Shape mismatch got %r != expected %r." % (
                        output.shape, expect[1].shape))
            diff = numpy.abs(output - expect[1]).ravel()
            absolute = diff.max()
            relative = absolute / numpy.median(diff) if absolute > 0 else 0.
            if absolute > abs_tol:
                raise ValueError(  # pragma: no cover
                    "Example %d, inferred and expected results are different "
                    "for output %d: abs=%r rel=%r (runtime=%r)."
                    "" % (index, i, absolute, relative, runtime))
            rows.append(dict(name=name, i=i, abs=absolute, rel=relative))
    return rows
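

# Hedged end-to-end sketch (illustration only, not part of the original
# module): downloads one zoo model, then replays its stored test data with
# the 'python' runtime of OnnxInference. The model name, cache folder and
# runtime choice are assumptions; verify_model raises if any output differs
# from the expected one by more than abs_tol. Run it as a module, e.g.
# ``python -m mlprodict.tools.zoo``, so that the relative imports resolve.
if __name__ == "__main__":  # pragma: no cover
    local_onnx, test_examples = download_model_data("mobilenet", cache=".")
    for row in verify_model(local_onnx, test_examples, runtime='python'):
        print(row)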