Patches (PyTorch export)#
This page answers common “how do I…” questions for applying and writing
patches — temporary function replacements that make torch.export.export()
succeed on models that would otherwise crash or produce incorrect graphs during
symbolic tracing.
Note
Patches are only relevant when exporting a torch.nn.Module with
torch.export.export(). They have no effect on ONNX models built
directly via the builder APIs or when exporting from scikit-learn,
TensorFlow or other non-PyTorch frameworks.
For background on why patches are needed, see Patches (torch export).
How to apply the built-in patches when exporting#
Wrap the torch.export.export() call with
apply_patches_for_model(). Use patch_torch=True
to activate the patches that fix symbolic-shape handling inside torch,
and patch_transformers=True for models that use 🤗 Transformers internals
(e.g. RotaryEmbedding).
The example below exports TinyBroadcastAddModel,
whose output has the symbolic shape (batch, max(d1, d2)) due to broadcasting.
This model requires both patch_torch=True and
torch.fx.experimental._config.patch(backed_size_oblivious=True) to export
successfully:
<<<
import torch
from yobx.torch import apply_patches_for_model, use_dyn_not_str
from yobx.torch.tiny_models import TinyBroadcastAddModel

model = TinyBroadcastAddModel()
inputs = TinyBroadcastAddModel._export_inputs()
dynamic_shapes = use_dyn_not_str(TinyBroadcastAddModel._dynamic_shapes())

with (
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
    apply_patches_for_model(patch_torch=True) as details,
):
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)

print(f"Applied {details.n_patches} patch(es).")
print(ep)
>>>
Applied 5 patch(es).
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s77, s27]", y: "f32[s17, s94]"):
# File: /home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/torch/tiny_models.py:129 in forward, code: return x + y
add: "f32[Max(s17, s77), Max(s27, s94)]" = torch.ops.aten.add.Tensor(x, y); x = y = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s77: VR[0, int_oo], s27: VR[0, int_oo], s17: VR[0, int_oo], s94: VR[1, 2]}
The patches are automatically removed when the with block exits, leaving
the original PyTorch functions fully restored.
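The apply-then-restore behaviour can be observed in isolation with the standard library. The sketch below is only an analogy: it uses `unittest.mock.patch.object` on `math.sqrt`, and `fake_sqrt` is a hypothetical replacement invented for the illustration. It does not show how `apply_patches_for_model()` is implemented.

```python
import math
from unittest import mock


def fake_sqrt(x):
    """Hypothetical replacement, used only for this illustration."""
    return -1.0


original = math.sqrt

# Inside the block the attribute points at the replacement.
with mock.patch.object(math, "sqrt", fake_sqrt):
    inside = math.sqrt(4.0)

# After the block the original attribute is back in place.
after = math.sqrt(4.0)

print("inside the block:", inside)
print("after the block:", after)
print("original restored:", math.sqrt is original)
```

Code that runs outside the `with` block, before or after it, therefore sees the unpatched function.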
How to list the patches that were applied#
The context manager yields a PatchDetails
object. Iterate over it to see the name and family of every
PatchInfo that was registered:
<<<
from yobx.torch import apply_patches_for_model

with apply_patches_for_model(patch_torch=True) as details:
    for patch in details:
        print(f"[{patch.family}] {patch.name}")
>>>
[torch] _print_Symbol
[torch] patched_infer_size
[torch] patched__broadcast_shapes
[torch] patched__get_range_constraints
[torch] patched__maybe_broadcast
How to view the diff for each patch#
format_diff() returns a unified
diff that shows exactly what changed between the original PyTorch function and
the patched replacement. This is useful for auditing what the library is
doing and for debugging unexpected behaviour.
<<<
from yobx.torch import apply_patches_for_model

with apply_patches_for_model(patch_torch=True) as details:
    pass  # patches are removed on exit but diffs remain accessible

# Show the first few lines of each diff.
for patch in details:
    diff_lines = patch.format_diff(format="raw").splitlines()
    print(f"=== {patch.name} ===")
    print("\n".join(diff_lines[:8]))
    print("...")
    print()
>>>
=== _print_Symbol ===
torch: DynamicDimConstraintPrinter._print_Symbol -> patched_DynamicDimConstraintPrinter._print_Symbol
--- original
+++ rewritten
@@ -1,6 +1,7 @@
-def _print_Symbol(self, expr: sympy.Symbol) -> str:
- if not isinstance(expr, sympy.Symbol):
- raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
- if not self.symbol_to_source.get(expr):
...
=== patched_infer_size ===
torch: infer_size -> patched_infer_size
--- original
+++ rewritten
@@ -1,10 +1,15 @@
-def infer_size(a: Sequence[IntLikeType], b: Sequence[IntLikeType]) -> tuple[IntLikeType, ...]:
- from torch.fx.experimental.symbolic_shapes import guard_or_false
+def patched_infer_size(a, b):
+ """
...
=== patched__broadcast_shapes ===
torch: _broadcast_shapes -> patched__broadcast_shapes
--- original
+++ rewritten
@@ -1,12 +1,12 @@
-def _broadcast_shapes(*_shapes):
- from torch.fx.experimental.symbolic_shapes import (
- guard_or_false,
- guarding_hint_or_throw,
...
=== patched__get_range_constraints ===
torch: _get_range_constraints -> patched__get_range_constraints
--- original
+++ rewritten
@@ -1,42 +1,36 @@
-def _get_range_constraints(
+def patched__get_range_constraints(
mod: torch.nn.Module,
- export_artifact: ExportArtifact,
...
=== patched__maybe_broadcast ===
torch: _maybe_broadcast -> patched__maybe_broadcast
--- original
+++ rewritten
@@ -1,15 +1,18 @@
-def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+ """
+ Patches ``torch._refs._maybe_broadcast``.
...
Pass format="rst" to get a reStructuredText block with a cross-reference
anchor; the Patches List page is generated this way.
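A unified diff in the same shape as the output above can be produced with the standard library's `difflib`. The sketch below is an assumption about the general technique, not the code behind `format_diff()`; the two source strings are invented for the example and do not come from PyTorch.

```python
import difflib

# Hypothetical "before" and "after" sources, invented for this example.
original_src = '''def scale(x, factor):
    return x * factor
'''
patched_src = '''def patched_scale(x, factor):
    """Patched: tolerates factor=None."""
    if factor is None:
        return x
    return x * factor
'''

# unified_diff works on lists of lines; keepends=True preserves newlines
# so the pieces can be joined back without extra separators.
diff = "".join(
    difflib.unified_diff(
        original_src.splitlines(keepends=True),
        patched_src.splitlines(keepends=True),
        fromfile="original",
        tofile="rewritten",
    )
)
print(diff)
```

Lines starting with `-` exist only in the original source and lines starting with `+` only in the replacement, which is how the diffs above should be read.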
How to write a custom patch#
Use make() to create a
PatchInfo that swaps a function or method
in a module for the duration of the export:
<<<
import torch
import torch._refs
from yobx.helpers.patch_helper import PatchInfo


def my_broadcast_shapes(*_shapes):
    """Minimal stand-in that returns the first non-empty shape."""
    for s in _shapes:
        if s:
            return list(s)
    return []


patch = PatchInfo.make(
    my_broadcast_shapes,
    torch._refs,
    "_broadcast_shapes",
    family="torch",
)

patch.do()
print("active patch:", patch.name)
print("patched function:", torch._refs._broadcast_shapes.__name__)
patch.undo()
print("after undo:", torch._refs._broadcast_shapes.__name__)
>>>
active patch: my_broadcast_shapes
patched function: my_broadcast_shapes
after undo: _broadcast_shapes
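The do()/undo() cycle shown above can be pictured as a save-and-restore of one attribute. The class below is a hypothetical stand-in written for illustration: `SimplePatch` and `fake_module` are invented names, and nothing here is taken from the actual PatchInfo implementation.

```python
import types


class SimplePatch:
    """Hypothetical stand-in for illustration; not the yobx PatchInfo class."""

    def __init__(self, patch, module_or_class, name, family=""):
        self.patch = patch
        self.target = module_or_class
        self.name = name
        self.family = family
        self._original = None

    def do(self):
        # Remember the current attribute, then install the replacement.
        self._original = getattr(self.target, self.name)
        setattr(self.target, self.name, self.patch)

    def undo(self):
        # Put the saved attribute back.
        setattr(self.target, self.name, self._original)


# A throwaway namespace plays the role of the patched module.
fake_module = types.SimpleNamespace(greet=lambda: "original")

p = SimplePatch(lambda: "patched", fake_module, "greet", family="demo")
p.do()
during = fake_module.greet()
p.undo()
after = fake_module.greet()
print(during, after)
```

Calling `undo()` without a matching `do()` would install `None` in this sketch, so in practice pairing the two calls (or using a context manager) is the safer pattern.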
The four arguments to make() are:
patch: the replacement callable.
module_or_class: the module or class whose attribute is replaced.
method_or_function_name: the attribute name (string) to patch.
family: a free-form category label (e.g. "torch" or "transformers") used in
diffs and reports.
To add the custom patch alongside the built-in ones, pass it via the
extra_patches argument of
apply_patches_for_model():
<<<
import torch
import torch._refs
from yobx.helpers.patch_helper import PatchInfo
from yobx.torch import apply_patches_for_model


def my_broadcast_shapes(*_shapes):
    """Minimal stand-in that returns the first non-empty shape."""
    for s in _shapes:
        if s:
            return list(s)
    return []


my_patch = PatchInfo.make(
    my_broadcast_shapes,
    torch._refs,
    "_broadcast_shapes",
    family="custom",
)

with apply_patches_for_model(extra_patches=[my_patch]) as details:
    print(f"Total patches: {details.n_patches}")
    for p in details:
        print(f"  [{p.family}] {p.name}")
    print("patched function:", torch._refs._broadcast_shapes.__name__)

print("after context, function restored:", torch._refs._broadcast_shapes.__name__)
>>>
Total patches: 1
[custom] my_broadcast_shapes
patched function: my_broadcast_shapes
after context, function restored: _broadcast_shapes
How to identify which patches affected the exported graph#
After export, call
patches_involved_in_graph() with
the torch.fx.Graph from the
ExportedProgram. The method cross-references the
stack_trace metadata on each FX node with the source location of every
registered patch and returns (PatchInfo, [node, …]) pairs.
<<<
import torch
from yobx.torch import apply_patches_for_model, use_dyn_not_str
from yobx.torch.tiny_models import TinyBroadcastAddModel

model = TinyBroadcastAddModel()
inputs = TinyBroadcastAddModel._export_inputs()
dynamic_shapes = use_dyn_not_str(TinyBroadcastAddModel._dynamic_shapes())

with (
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
    apply_patches_for_model(patch_torch=True) as details,
):
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)

patches = details.patches_involved_in_graph(ep.graph)
print(f"Patches that contributed nodes: {len(patches)}")
for patch_info, nodes in patches:
    node_names = [n.name for n in nodes]
    print(f"  {patch_info.name}: {node_names}")
>>>
Patches that contributed nodes: 0
Use make_report() to produce a human-readable summary of every involved
patch together with its diff, reusing details and patches from the example above:
report = details.make_report(patches, format="raw")
print(report)
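The cross-referencing idea described in this section, matching each node's stack trace against the source location of each patch, can be sketched in plain Python. Everything below is an assumption made for illustration: the function `patches_for_nodes`, the trace strings, and the `(file, first_line, last_line)` location tuples are hypothetical and do not reflect how `patches_involved_in_graph()` stores or compares its data.

```python
import re
from collections import defaultdict

# Matches frames of the form: File "path.py", line 42
FRAME = re.compile(r'File "(?P<file>[^"]+)", line (?P<line>\d+)')


def patches_for_nodes(node_traces, patch_locations):
    """Map each patch name to the nodes whose trace passes through it.

    node_traces: {node_name: stack trace text}
    patch_locations: {patch_name: (file, first_line, last_line)}
    """
    involved = defaultdict(list)
    for node, trace in node_traces.items():
        frames = [(m["file"], int(m["line"])) for m in FRAME.finditer(trace)]
        for patch, (file, first, last) in patch_locations.items():
            # A node is attributed to a patch when any frame of its trace
            # falls inside the patch's source range.
            if any(f == file and first <= n <= last for f, n in frames):
                involved[patch].append(node)
    return dict(involved)


# Hypothetical traces and locations, invented for the example.
node_traces = {
    "add": 'File "model.py", line 12, in forward\n    return x + y',
    "mul": 'File "model.py", line 13, in forward\n'
           'File "patches.py", line 42, in patched_fn\n    return a * b',
}
patch_locations = {"patched_fn": ("patches.py", 40, 50)}

result = patches_for_nodes(node_traces, patch_locations)
print(result)
```

Under this reading, an empty result such as the one above means that no node's recorded trace passed through a patched function; it does not by itself say whether the patches were needed for the export to succeed.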
See also
Patches (torch export) — background on why patches are needed and how
the PatchInfo and
PatchDetails data structures work.
Patches List — the full list of shipped patches with unified diffs.
Applying patches to a model and displaying the diff — a gallery example that applies patches, displays the diffs, and identifies which patches were exercised when exporting a real Transformers model.
yobx.helpers.patch_helper — API reference for
PatchInfo and
PatchDetails.
yobx.torch.patch — API reference for
apply_patches_for_model().