# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import functools
import glob
import os
import re
import shutil
import sys
import tarfile
import tempfile
import time
import unittest
from collections import defaultdict
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Pattern,
Sequence,
Set,
Type,
Union,
)
from urllib.request import urlretrieve
import numpy as np # type: ignore
import onnx
from onnx import ModelProto, NodeProto, TypeProto, numpy_helper
from onnx.backend.base import Backend
from ..case.test_case import TestCase
from ..loader import load_model_tests
from .item import TestItem
class BackendIsNotSupposedToImplementIt(unittest.SkipTest):
    """Signal that the backend is not expected to implement what a test checks.

    This is a ``unittest.SkipTest`` subclass, so if it escapes a test,
    unittest reports that test as skipped.
    """
def retry_excute(
    times: int, delay: float = 5
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that retries the wrapped callable on any exception.

    The wrapped callable is attempted up to ``times`` times. After the i-th
    failed attempt (1-based), ``"<i> times tried"`` is printed. Unless that was
    the final attempt, the wrapper then sleeps ``delay * i`` seconds and tries
    again. The exception from the final attempt is re-raised unchanged.

    Args:
        times: Maximum number of attempts; must be at least 1.
        delay: Base back-off in seconds. The sleep after attempt ``i`` is
            ``delay * i``. Defaults to 5, the historical value.

    Returns:
        A decorator that applies the retry policy.

    Raises:
        ValueError: If ``times`` is less than 1. An explicit check is used
            rather than ``assert``, because asserts are stripped under ``-O``.
    """
    if times < 1:
        raise ValueError(f"times must be >= 1, got {times}")

    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            for i in range(1, times + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    print(f"{i} times tried")
                    if i == times:
                        raise
                    # Linear back-off between attempts.
                    time.sleep(delay * i)

        return wrapped

    return wrapper
class Runner:
    """Collect the ONNX backend test data and expose it as unittest cases.

    Each test returned by ``load_model_tests`` is registered once per device
    (``"CPU"`` and ``"CUDA"`` by default) under a category such as
    ``"NodeModel"``. The ``include``, ``exclude`` and ``xfail`` patterns are
    regular expressions. They are matched against test names when
    ``test_cases``, ``test_suite`` or ``tests`` is read.
    """

    def __init__(
        self, backend: Type[Backend], parent_module: Optional[str] = None
    ) -> None:
        """Register every known model test for ``backend``.

        Args:
            backend: Backend class under test. Its ``supports_device`` and
                ``prepare`` methods are used, and ``is_compatible`` is used
                when present.
            parent_module: If given, it is assigned as ``__module__`` of the
                generated ``unittest.TestCase`` subclasses.
        """
        self.backend = backend
        self._parent_module = parent_module
        self._include_patterns: Set[Pattern[str]] = set()
        self._exclude_patterns: Set[Pattern[str]] = set()
        self._xfail_patterns: Set[Pattern[str]] = set()

        # This is the source of truth for all test functions. The properties
        # `test_cases`, `test_suite` and `tests` are derived from it.
        # {category: {name: TestItem}}
        self._test_items: Dict[str, Dict[str, TestItem]] = defaultdict(dict)

        # (loader kind, category prefix), registered in this order.
        for kind, category_prefix in (
            ("node", "Node"),
            ("real", "Real"),
            ("simple", "Simple"),
            ("pytorch-converted", "PyTorchConverted"),
            ("pytorch-operator", "PyTorchOperator"),
        ):
            for model_test in load_model_tests(kind=kind):
                self._add_model_test(model_test, category_prefix)

    def _get_test_case(self, name: str) -> Type[unittest.TestCase]:
        """Create an empty ``unittest.TestCase`` subclass called ``name``."""
        test_case = type(str(name), (unittest.TestCase,), {})
        if self._parent_module:
            test_case.__module__ = self._parent_module
        return test_case

    def include(self, pattern: str) -> Runner:
        """Run only tests whose name matches some include regex (chainable)."""
        self._include_patterns.add(re.compile(pattern))
        return self

    def exclude(self, pattern: str) -> Runner:
        """Skip tests whose name matches ``pattern`` (chainable)."""
        self._exclude_patterns.add(re.compile(pattern))
        return self

    def xfail(self, pattern: str) -> Runner:
        """Mark tests whose name matches ``pattern`` as expected failures."""
        self._xfail_patterns.add(re.compile(pattern))
        return self

    def enable_report(self) -> Runner:
        """Attach pytest's ``onnx_coverage`` marker to every registered test.

        The marker is given each item's ``proto`` and its category. It is
        presumably consumed by a coverage-reporting pytest plugin defined
        elsewhere. Requires pytest to be importable.
        """
        import pytest  # type: ignore

        for category, items_map in self._test_items.items():
            for item in items_map.values():
                item.func = pytest.mark.onnx_coverage(item.proto, category)(item.func)
        return self

    @property
    def _filtered_test_items(self) -> Dict[str, Dict[str, TestItem]]:
        """Return the registered tests with include/exclude/xfail applied.

        Each returned item is a shallow copy of the registered one, with
        ``func`` wrapped as needed. Repeated accesses therefore do not stack
        skip decorators onto the registry.
        """
        import copy

        filtered: Dict[str, Dict[str, TestItem]] = {}
        for category, items_map in self._test_items.items():
            filtered[category] = {}
            for name, item in items_map.items():
                func = item.func
                if self._include_patterns and not any(
                    include.search(name) for include in self._include_patterns
                ):
                    func = unittest.skip("no matched include pattern")(func)
                for exclude in self._exclude_patterns:
                    if exclude.search(name):
                        func = unittest.skip(
                            f'matched exclude pattern "{exclude.pattern}"'
                        )(func)
                if any(xfail.search(name) for xfail in self._xfail_patterns):
                    func = unittest.expectedFailure(func)
                filtered_item = copy.copy(item)
                filtered_item.func = func
                filtered[category][name] = filtered_item
        return filtered

    @property
    def test_cases(self) -> Dict[str, Type[unittest.TestCase]]:
        """
        List of test cases to be applied on the parent scope
        Example usage:
            globals().update(BackendTest(backend).test_cases)

        Keys are ``f"OnnxBackend{category}Test"``; there is one TestCase
        class per category.
        """
        test_cases = {}
        for category, items_map in self._filtered_test_items.items():
            test_case_name = f"OnnxBackend{category}Test"
            test_case = self._get_test_case(test_case_name)
            for name, item in sorted(items_map.items()):
                setattr(test_case, name, item.func)
            test_cases[test_case_name] = test_case
        return test_cases

    @property
    def test_suite(self) -> unittest.TestSuite:
        """
        TestSuite that can be run by TestRunner
        Example usage:
            unittest.TextTestRunner().run(BackendTest(backend).test_suite)
        """
        suite = unittest.TestSuite()
        # Order by TestCase name. Class objects themselves are not orderable,
        # so sorting the values directly would raise TypeError.
        for _, case in sorted(self.test_cases.items()):
            suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
        return suite

    # For backward compatibility (we used to expose `.tests`)
    @property
    def tests(self) -> Type[unittest.TestCase]:
        """
        One single unittest.TestCase that hosts all the test functions
        Example usage:
            onnx_backend_tests = BackendTest(backend).tests
        """
        tests = self._get_test_case("OnnxBackendTest")
        filtered = self._filtered_test_items
        # Iterate categories by name. Dicts are not orderable, so sorting the
        # values directly would raise TypeError.
        for category in sorted(filtered):
            for name, item in sorted(filtered[category].items()):
                setattr(tests, name, item.func)
        return tests

    @classmethod
    def assert_similar_outputs(
        cls,
        ref_outputs: Sequence[Any],
        outputs: Sequence[Any],
        rtol: float,
        atol: float,
    ) -> None:
        """Assert that ``outputs`` match ``ref_outputs`` element by element.

        Nested lists/tuples (sequence outputs) are compared recursively,
        including their lengths. Array-like leaves must have equal dtypes.
        Object-dtype leaves are compared exactly; all others are compared
        with ``np.testing.assert_allclose`` using ``rtol``/``atol``.
        """
        np.testing.assert_equal(len(outputs), len(ref_outputs))
        for output, ref_output in zip(outputs, ref_outputs):
            if isinstance(output, (list, tuple)):
                # Recurse on the whole sequence so its length is checked and
                # each element is compared as a unit (this also handles 0-d
                # array elements).
                cls.assert_similar_outputs(ref_output, output, rtol, atol)
            else:
                np.testing.assert_equal(output.dtype, ref_output.dtype)
                # `np.object` was removed in NumPy 1.24; the builtin works.
                if ref_output.dtype == object:
                    np.testing.assert_array_equal(output, ref_output)
                else:
                    np.testing.assert_allclose(
                        output, ref_output, rtol=rtol, atol=atol
                    )

    @classmethod
    @retry_excute(3)
    def download_model(
        cls, model_test: TestCase, model_dir: str, models_dir: str
    ) -> None:
        """Download ``model_test.url`` and extract the tarball into ``models_dir``.

        Args:
            model_test: Test whose ``url`` and ``model_name`` are used.
            model_dir: Unused here; kept for interface compatibility.
            models_dir: Directory the archive is extracted into.

        Failures are printed and re-raised. The ``retry_excute(3)`` decorator
        then retries the whole download.
        """
        # On Windows, NamedTemporaryFile can not be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            assert model_test.url
            print(
                f"Start downloading model {model_test.model_name} "
                f"from {model_test.url}"
            )
            urlretrieve(model_test.url, download_file.name)
            print("Done")
            with tarfile.open(download_file.name) as t:
                # The archive comes from the network. Where the interpreter
                # supports extraction filters, use the "data" filter to reject
                # absolute paths, links escaping the destination, and similar.
                if hasattr(tarfile, "data_filter"):
                    t.extractall(models_dir, filter="data")
                else:
                    t.extractall(models_dir)
        except Exception as e:
            print(f"Failed to prepare data for model {model_test.model_name}: {e}")
            raise
        finally:
            os.remove(download_file.name)

    @classmethod
    def prepare_model_data(cls, model_test: TestCase) -> str:
        """Return the local directory of ``model_test``'s model, downloading if needed.

        The root is ``$ONNX_MODELS`` if set, otherwise ``$ONNX_HOME/models``
        (with ``ONNX_HOME`` defaulting to ``~/.onnx``). If the model directory
        exists without a ``model.onnx``, it is moved aside to the first free
        ``<dir>.old.<n>`` before downloading again.
        """
        onnx_home = os.path.expanduser(
            os.getenv("ONNX_HOME", os.path.join("~", ".onnx"))
        )
        models_dir = os.getenv("ONNX_MODELS", os.path.join(onnx_home, "models"))
        model_dir: str = os.path.join(models_dir, model_test.model_name)
        if not os.path.exists(os.path.join(model_dir, "model.onnx")):
            if os.path.exists(model_dir):
                bi = 0
                while True:
                    dest = f"{model_dir}.old.{bi}"
                    if os.path.exists(dest):
                        bi += 1
                        continue
                    shutil.move(model_dir, dest)
                    break
            os.makedirs(model_dir)

            cls.download_model(
                model_test=model_test, model_dir=model_dir, models_dir=models_dir
            )
        return model_dir

    def _add_test(
        self,
        category: str,
        test_name: str,
        test_func: Callable[..., Any],
        report_item: List[Optional[Union[ModelProto, NodeProto]]],
        devices: Iterable[str] = ("CPU", "CUDA"),
    ) -> None:
        """Register ``test_func`` once per device in ``devices`` under ``category``.

        Each registration is named ``f"{test_name}_{device.lower()}"`` and is
        skipped when ``backend.supports_device(device)`` is false.

        Raises:
            ValueError: If ``test_name`` does not start with ``test_``, or a
                device-specific name is already registered in ``category``.
        """
        # We don't prepend the 'test_' prefix to improve greppability
        if not test_name.startswith("test_"):
            raise ValueError(f"Test name must start with test_: {test_name}")

        def add_device_test(device: str) -> None:
            device_test_name = f"{test_name}_{device.lower()}"
            if device_test_name in self._test_items[category]:
                raise ValueError(
                    f'Duplicated test name "{device_test_name}" '
                    f'in category "{category}"'
                )

            @unittest.skipIf(  # type: ignore
                not self.backend.supports_device(device),
                f"Backend doesn't support device {device}",
            )
            @functools.wraps(test_func)
            def device_test_func(*args: Any, **kwargs: Any) -> Any:
                try:
                    return test_func(*args, device=device, **kwargs)
                except BackendIsNotSupposedToImplementIt as e:
                    # hacky verbose reporting: the exception is swallowed, so
                    # the test passes; only mention it when run verbosely.
                    if "-v" in sys.argv or "--verbose" in sys.argv:
                        print(f"Test {device_test_name} is effectively skipped: {e}")

            self._test_items[category][device_test_name] = TestItem(
                device_test_func, report_item
            )

        for device in devices:
            add_device_test(device)

    def _add_model_test(self, model_test: TestCase, kind: str) -> None:
        """Register a test that loads, prepares and runs ``model_test``'s model.

        Backend outputs are compared with reference data stored next to
        ``model.onnx``. The data comes either from legacy
        ``test_data_*.npz`` files or from ``test_data_set*`` directories of
        ``input_<i>.pb`` / ``output_<i>.pb`` protobufs. The category is
        ``kind + "Model"``.
        """
        # model is loaded at runtime, note sometimes it could even
        # never loaded if the test skipped
        model_marker: List[Optional[Union[ModelProto, NodeProto]]] = [None]

        def run(test_self: Any, device: str) -> None:
            if model_test.model_dir is None:
                model_dir = self.prepare_model_data(model_test)
            else:
                model_dir = model_test.model_dir
            model_pb_path = os.path.join(model_dir, "model.onnx")
            model = onnx.load(model_pb_path)
            model_marker[0] = model
            if (
                hasattr(self.backend, "is_compatible")
                and callable(self.backend.is_compatible)
                and not self.backend.is_compatible(model)
            ):
                raise unittest.SkipTest("Not compatible with backend")
            prepared_model = self.backend.prepare(model, device)
            assert prepared_model is not None

            # TODO after converting all npz files to protobuf, we can delete this.
            for test_data_npz in sorted(
                glob.glob(os.path.join(model_dir, "test_data_*.npz"))
            ):
                # Read both arrays, then close the NpzFile handle.
                with np.load(test_data_npz, encoding="bytes") as test_data:
                    inputs = list(test_data["inputs"])
                    ref_outputs = test_data["outputs"]
                outputs = list(prepared_model.run(inputs))
                self.assert_similar_outputs(
                    ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
                )

            for test_data_dir in sorted(
                glob.glob(os.path.join(model_dir, "test_data_set*"))
            ):
                inputs = []
                inputs_num = len(glob.glob(os.path.join(test_data_dir, "input_*.pb")))
                for i in range(inputs_num):
                    input_file = os.path.join(test_data_dir, f"input_{i}.pb")
                    self._load_proto(input_file, inputs, model.graph.input[i].type)
                ref_outputs = []
                ref_outputs_num = len(
                    glob.glob(os.path.join(test_data_dir, "output_*.pb"))
                )
                for i in range(ref_outputs_num):
                    output_file = os.path.join(test_data_dir, f"output_{i}.pb")
                    self._load_proto(
                        output_file, ref_outputs, model.graph.output[i].type
                    )
                outputs = list(prepared_model.run(inputs))
                self.assert_similar_outputs(
                    ref_outputs, outputs, rtol=model_test.rtol, atol=model_test.atol
                )

        self._add_test(kind + "Model", model_test.name, run, model_marker)

    def _load_proto(
        self,
        proto_filename: str,
        target_list: List[Union[np.ndarray, List[Any]]],
        model_type_proto: TypeProto,
    ) -> None:
        """Parse ``proto_filename`` according to ``model_type_proto`` and append it.

        Sequence, tensor and optional types are converted with
        ``numpy_helper`` and appended to ``target_list``. Any other type
        (e.g. map or sparse tensor) only prints a message and appends nothing.
        NOTE(review): a skipped entry shifts the positions of any later entries.
        """
        with open(proto_filename, "rb") as f:
            protobuf_content = f.read()
            if model_type_proto.HasField("sequence_type"):
                sequence = onnx.SequenceProto()
                sequence.ParseFromString(protobuf_content)
                target_list.append(numpy_helper.to_list(sequence))
            elif model_type_proto.HasField("tensor_type"):
                tensor = onnx.TensorProto()
                tensor.ParseFromString(protobuf_content)
                target_list.append(numpy_helper.to_array(tensor))
            elif model_type_proto.HasField("optional_type"):
                optional = onnx.OptionalProto()
                optional.ParseFromString(protobuf_content)
                target_list.append(numpy_helper.to_optional(optional))
            else:
                print(
                    "Loading proto of that specific type (Map/Sparse Tensor) is currently not supported"
                )