2021-08-12 A few tricks for tf2onnx
A few things I tend to forget. To run a specific test on a specific opset:
python tests/test_backend.py --opset 12 BackendTests.test_rfft2d_ops_specific_dimension
Optimisation of an ONNX file: the script below applies the whole list of optimizers available in tensorflow-onnx.
import logging
import onnx
from onnx import helper
from tf2onnx.graph import GraphUtil
from tf2onnx import logging, optimizer, constants
from tf2onnx.late_rewriters import rewrite_channels_first, rewrite_channels_last
# Turn on DEBUG-level logging for the run. Note that `logging` here is the
# module imported from tf2onnx: that import comes after the plain
# `import logging` and rebinds the name.
logging.basicConfig(level=logging.DEBUG)
def load_graph(fname, target):
    """Read a serialized ONNX model and wrap it in a tf2onnx graph.

    Args:
        fname: path of the binary ONNX file to read.
        target: value forwarded unchanged to
            ``GraphUtil.create_graph_from_onnx_model``.

    Returns:
        A ``(graph, model_proto)`` pair: the graph built by ``GraphUtil``
        and the ``onnx.ModelProto`` it was created from.
    """
    with open(fname, "rb") as stream:
        raw_bytes = stream.read()

    # Deserialize the protobuf payload into a fresh model container.
    proto = onnx.ModelProto()
    proto.ParseFromString(raw_bytes)

    graph = GraphUtil.create_graph_from_onnx_model(proto, target)
    return graph, proto
def optimize(input, output):
    """Optimize the ONNX model stored in *input* and save it to *output*.

    The model is loaded with ``load_graph`` (empty target list), the
    channels-first / channels-last late rewriters are applied when the graph
    reports the matching target, and the graph is then passed through
    ``optimizer.optimize_graph``. The result is rebuilt into a model that
    reuses the properties of the original model and is written as binary
    protobuf.

    Args:
        input: path of the ONNX file to optimize.
        output: path where the optimized ONNX file is written.
    """
    graph, source_proto = load_graph(input, [])

    # Each late rewriter only runs when the graph declares its target.
    late_rewriters = (
        (constants.TARGET_CHANNELS_FIRST, rewrite_channels_first),
        (constants.TARGET_CHANNELS_LAST, rewrite_channels_last),
    )
    for target_flag, rewriter in late_rewriters:
        if graph.is_target(target_flag):
            graph.reset_nodes(rewriter(graph, graph.get_nodes()))

    graph = optimizer.optimize_graph(graph)

    # Keep the original doc string and tag it so the file is recognizable
    # as an optimized model.
    doc_string = source_proto.graph.doc_string + " (+tf2onnx/onnx-optimize)"
    optimized_graph = graph.make_graph(doc_string)

    # Reuse the model-level properties extracted from the original model.
    model_properties = GraphUtil.get_onnx_model_properties(source_proto)
    optimized_proto = helper.make_model(optimized_graph, **model_properties)

    serialized = optimized_proto.SerializeToString()
    with open(output, "wb") as sink:
        sink.write(serialized)
# Example run: read "debug_noopt.onnx" and write the optimized model to
# "debug_opt.onnx".
optimize("debug_noopt.onnx", "debug_opt.onnx")