Patches (PyTorch export)#

This page answers common “how do I…” questions for applying and writing patches — temporary function replacements that make torch.export.export() succeed on models that would otherwise crash or produce incorrect graphs during symbolic tracing.

Note

Patches are only relevant when exporting a torch.nn.Module with torch.export.export(). They have no effect on ONNX models built directly via the builder APIs or when exporting from scikit-learn, TensorFlow or other non-PyTorch frameworks.

For background on why patches are needed, see Patches (torch export).


How to apply the built-in patches when exporting#

Wrap the torch.export.export() call with apply_patches_for_model(). Use patch_torch=True to activate the patches that fix symbolic-shape handling inside torch, and patch_transformers=True for models that use 🤗 Transformers internals (e.g. RotaryEmbedding).

The example below exports TinyBroadcastAddModel, which adds two inputs with broadcasting, so each output dimension is the symbolic maximum of the corresponding input dimensions (the exported graph reports Max(s17, s77) and Max(s27, s94)). This model requires both patch_torch=True and torch.fx.experimental._config.patch(backed_size_oblivious=True) to export successfully:

<<<

import torch
from yobx.torch import apply_patches_for_model, use_dyn_not_str
from yobx.torch.tiny_models import TinyBroadcastAddModel

model = TinyBroadcastAddModel()
inputs = TinyBroadcastAddModel._export_inputs()
dynamic_shapes = use_dyn_not_str(TinyBroadcastAddModel._dynamic_shapes())

with (
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
    apply_patches_for_model(patch_torch=True) as details,
):
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)

print(f"Applied {details.n_patches} patch(es).")
print(ep)

>>>

    Applied 5 patch(es).
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[s77, s27]", y: "f32[s17, s94]"):
                # File: /home/runner/work/xadupre.github.io/xadupre.github.io/yet-another-onnx-builder/yobx/torch/tiny_models.py:129 in forward, code: return x + y
                add: "f32[Max(s17, s77), Max(s27, s94)]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                return (add,)
        
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {s77: VR[0, int_oo], s27: VR[0, int_oo], s17: VR[0, int_oo], s94: VR[1, 2]}

The patches are automatically removed when the with block exits, leaving the original PyTorch functions fully restored.
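
The restore-on-exit behaviour can be pictured with a small standard-library sketch. It is not the yobx implementation: swap_attribute and fake_lib are hypothetical names introduced only for illustration, and the sketch assumes a patch amounts to replacing one attribute and putting the original back in a finally clause, so that a failed export also leaves the target untouched.

```python
import contextlib
import types


@contextlib.contextmanager
def swap_attribute(owner, name, replacement):
    """Temporarily replace ``owner.<name>`` and restore the original on exit."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        # The finally clause runs on normal exit and when the body raises.
        setattr(owner, name, original)


# Hypothetical stand-in for a library module (not torch, not yobx).
fake_lib = types.SimpleNamespace(scale=lambda x: x * 2)
original_scale = fake_lib.scale

with swap_attribute(fake_lib, "scale", lambda x: x * 10):
    inside = fake_lib.scale(3)  # the replacement is active here

try:
    with swap_attribute(fake_lib, "scale", lambda x: 0):
        raise RuntimeError("simulated export failure")
except RuntimeError:
    pass  # the original was restored before the exception left the block

restored = fake_lib.scale is original_scale
print(inside, fake_lib.scale(3), restored)
```

The second with block is the interesting case: the body raises, yet the original attribute is back in place afterwards.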


How to list the patches that were applied#

The context manager yields a PatchDetails object. Iterate over it to see the name and family of every PatchInfo that was registered:

<<<

from yobx.torch import apply_patches_for_model

with apply_patches_for_model(patch_torch=True) as details:
    for patch in details:
        print(f"[{patch.family}] {patch.name}")

>>>

    [torch] _print_Symbol
    [torch] patched_infer_size
    [torch] patched__broadcast_shapes
    [torch] patched__get_range_constraints
    [torch] patched__maybe_broadcast

How to view the diff for each patch#

format_diff() returns a unified diff that shows exactly what changed between the original PyTorch function and the patched replacement. This is useful for auditing what the library is doing and for debugging unexpected behaviour.

<<<

from yobx.torch import apply_patches_for_model

with apply_patches_for_model(patch_torch=True) as details:
    pass  # patches are removed on exit but diffs remain accessible

# Show the first few lines of each diff.
for patch in details:
    diff_lines = patch.format_diff(format="raw").splitlines()
    print(f"=== {patch.name} ===")
    print("\n".join(diff_lines[:8]))
    print("...")
    print()

>>>

    === _print_Symbol ===
    torch: DynamicDimConstraintPrinter._print_Symbol -> patched_DynamicDimConstraintPrinter._print_Symbol
    --- original
    +++ rewritten
    @@ -1,6 +1,7 @@
    -def _print_Symbol(self, expr: sympy.Symbol) -> str:
    -    if not isinstance(expr, sympy.Symbol):
    -        raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
    -    if not self.symbol_to_source.get(expr):
    ...
    
    === patched_infer_size ===
    torch: infer_size -> patched_infer_size
    --- original
    +++ rewritten
    @@ -1,10 +1,15 @@
    -def infer_size(a: Sequence[IntLikeType], b: Sequence[IntLikeType]) -> tuple[IntLikeType, ...]:
    -    from torch.fx.experimental.symbolic_shapes import guard_or_false
    +def patched_infer_size(a, b):
    +    """
    ...
    
    === patched__broadcast_shapes ===
    torch: _broadcast_shapes -> patched__broadcast_shapes
    --- original
    +++ rewritten
    @@ -1,12 +1,12 @@
    -def _broadcast_shapes(*_shapes):
    -    from torch.fx.experimental.symbolic_shapes import (
    -        guard_or_false,
    -        guarding_hint_or_throw,
    ...
    
    === patched__get_range_constraints ===
    torch: _get_range_constraints -> patched__get_range_constraints
    --- original
    +++ rewritten
    @@ -1,42 +1,36 @@
    -def _get_range_constraints(
    +def patched__get_range_constraints(
         mod: torch.nn.Module,
    -    export_artifact: ExportArtifact,
    ...
    
    === patched__maybe_broadcast ===
    torch: _maybe_broadcast -> patched__maybe_broadcast
    --- original
    +++ rewritten
    @@ -1,15 +1,18 @@
    -def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    +def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    +    """
    +    Patches ``torch._refs._maybe_broadcast``.
    ...

Pass format="rst" to get a reStructuredText block with a cross-reference anchor, which is how Patches List is generated.
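
As a rough illustration of what such a diff contains, the sketch below compares two source strings with difflib from the standard library. Whether format_diff() uses difflib internally is an assumption; source_diff and both source strings are hypothetical. The labels original and rewritten are chosen to match the headers shown in the output above.

```python
import difflib

# Hypothetical sources; for real functions they could be obtained
# with inspect.getsource().
original_src = '''def infer(a, b):
    return max(a, b)
'''
rewritten_src = '''def patched_infer(a, b):
    """Patched version."""
    return max(a, b)
'''


def source_diff(original, rewritten):
    """Return a unified diff between two pieces of source code."""
    lines = difflib.unified_diff(
        original.splitlines(),
        rewritten.splitlines(),
        fromfile="original",
        tofile="rewritten",
        lineterm="",  # the inputs carry no newlines, so the headers should not either
    )
    return "\n".join(lines)


diff_text = source_diff(original_src, rewritten_src)
print(diff_text)
```

Lines starting with - belong to the original function and lines starting with + to the replacement, which is how the shipped diffs above should be read.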


How to write a custom patch#

Use PatchInfo.make() to create a PatchInfo that swaps a function or method on a module or class. The swap takes effect when do() is called and is reverted by undo():

<<<

import torch
import torch._refs
from yobx.helpers.patch_helper import PatchInfo


def my_broadcast_shapes(*_shapes):
    """Minimal stand-in that returns the first non-empty shape."""
    for s in _shapes:
        if s:
            return list(s)
    return []


patch = PatchInfo.make(
    my_broadcast_shapes,
    torch._refs,
    "_broadcast_shapes",
    family="torch",
)

patch.do()
print("active patch:", patch.name)
print("patched function:", torch._refs._broadcast_shapes.__name__)

patch.undo()
print("after undo:", torch._refs._broadcast_shapes.__name__)

>>>

    active patch: my_broadcast_shapes
    patched function: my_broadcast_shapes
    after undo: _broadcast_shapes

The four arguments to make() are:

  • patch — the replacement callable.

  • module_or_class — the module or class whose attribute is replaced.

  • method_or_function_name — the attribute name (string) to patch.

  • family — a free-form category label (e.g. "torch" or "transformers") used in diffs and reports.
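
To make the role of each argument concrete, here is a simplified analogue written with the standard library only. MiniPatch is a hypothetical class, not yobx's PatchInfo; it only assumes that a patch records its target, remembers the original attribute in do() and puts it back in undo(). The example patches a method on a class to show that module_or_class does not have to be a module.

```python
from dataclasses import dataclass, field
from typing import Any, Callable

_MISSING = object()  # sentinel meaning "patch not applied"


@dataclass
class MiniPatch:
    """Hypothetical, simplified analogue of a patch record."""

    patch: Callable
    module_or_class: Any
    method_or_function_name: str
    family: str = ""
    _original: Any = field(init=False, default=_MISSING, repr=False)

    @property
    def name(self):
        return self.patch.__name__

    def do(self):
        if self._original is not _MISSING:
            raise RuntimeError("patch already applied")
        # Remember the original so that undo() can restore it.
        self._original = getattr(self.module_or_class, self.method_or_function_name)
        setattr(self.module_or_class, self.method_or_function_name, self.patch)

    def undo(self):
        if self._original is _MISSING:
            raise RuntimeError("patch is not applied")
        setattr(self.module_or_class, self.method_or_function_name, self._original)
        self._original = _MISSING


class Greeter:
    def greet(self, who):
        return f"hello {who}"


def loud_greet(self, who):
    return f"HELLO {who.upper()}"


greeter = Greeter()
patch = MiniPatch(loud_greet, Greeter, "greet", family="demo")

patch.do()
during = greeter.greet("ada")  # existing instances see the replacement
patch.undo()
after = greeter.greet("ada")
print(f"[{patch.family}] {patch.name}: {during!r} -> {after!r}")
```

Patching the class attribute rather than the instance means every existing and future instance is affected while the patch is active.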

To add the custom patch alongside the built-in ones, pass it via the extra_patches argument of apply_patches_for_model():

<<<

import torch
import torch._refs
from yobx.helpers.patch_helper import PatchInfo
from yobx.torch import apply_patches_for_model


def my_broadcast_shapes(*_shapes):
    """Minimal stand-in that returns the first non-empty shape."""
    for s in _shapes:
        if s:
            return list(s)
    return []


my_patch = PatchInfo.make(
    my_broadcast_shapes,
    torch._refs,
    "_broadcast_shapes",
    family="custom",
)

with apply_patches_for_model(extra_patches=[my_patch]) as details:
    print(f"Total patches: {details.n_patches}")
    for p in details:
        print(f"  [{p.family}] {p.name}")
    print("patched function:", torch._refs._broadcast_shapes.__name__)

print("after context, function restored:", torch._refs._broadcast_shapes.__name__)

>>>

    Total patches: 1
      [custom] my_broadcast_shapes
    patched function: my_broadcast_shapes
    after context, function restored: _broadcast_shapes

How to identify which patches affected the exported graph#

After export, call patches_involved_in_graph() with the torch.fx.Graph from the ExportedProgram. The method cross-references the stack_trace metadata on each FX node with the source location of every registered patch and returns (PatchInfo, [node, …]) pairs.

<<<

import torch
from yobx.torch import apply_patches_for_model, use_dyn_not_str
from yobx.torch.tiny_models import TinyBroadcastAddModel

model = TinyBroadcastAddModel()
inputs = TinyBroadcastAddModel._export_inputs()
dynamic_shapes = use_dyn_not_str(TinyBroadcastAddModel._dynamic_shapes())

with (
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
    apply_patches_for_model(patch_torch=True) as details,
):
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)

patches = details.patches_involved_in_graph(ep.graph)
print(f"Patches that contributed nodes: {len(patches)}")
for patch_info, nodes in patches:
    node_names = [n.name for n in nodes]
    print(f"  {patch_info.name}: {node_names}")

>>>

    Patches that contributed nodes: 0
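
The cross-referencing idea can be sketched without torch. Everything below is hypothetical: FakeNode stands in for an FX node with a stack_trace string, PatchLocation stands in for the source location of a patch, and the matching rule (same file, line inside the patched function) is an assumption about how such a lookup could work, not a description of the yobx code.

```python
import re
from dataclasses import dataclass
from typing import Optional

# Matches the frame lines of a Python traceback.
FRAME_RE = re.compile(r'File "(?P<file>[^"]+)", line (?P<line>\d+), in (?P<func>\S+)')


@dataclass
class FakeNode:
    """Hypothetical stand-in for an FX node (only the fields used here)."""

    name: str
    stack_trace: Optional[str] = None


@dataclass
class PatchLocation:
    """Hypothetical source span of a patched function."""

    name: str
    filename: str
    first_line: int
    last_line: int


def nodes_touched_by(patches, nodes):
    """Return (patch, [node, ...]) pairs for patches found in a stack trace."""
    result = []
    for patch in patches:
        hits = []
        for node in nodes:
            for frame in FRAME_RE.finditer(node.stack_trace or ""):
                same_file = frame["file"] == patch.filename
                in_span = patch.first_line <= int(frame["line"]) <= patch.last_line
                if same_file and in_span:
                    hits.append(node)
                    break  # one matching frame is enough for this node
        if hits:
            result.append((patch, hits))
    return result


nodes = [
    FakeNode("x"),  # placeholders usually carry no stack trace
    FakeNode("add", 'File "/app/model.py", line 12, in forward\n    return x + y'),
    FakeNode(
        "expand",
        'File "/app/model.py", line 14, in forward\n    return f(x)\n'
        '  File "/app/patches.py", line 40, in patched_broadcast\n    return a.expand(s)',
    ),
]
patches = [
    PatchLocation("patched_broadcast", "/app/patches.py", 30, 55),
    PatchLocation("patched_infer_size", "/app/patches.py", 60, 80),
]

involved = nodes_touched_by(patches, nodes)
summary = [(p.name, [n.name for n in ns]) for p, ns in involved]
print(summary)
```

Under this rule a patch is reported only when one of its own source lines appears in a node's stack trace, so a patch can be applied during export and still contribute no nodes.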

Use make_report() to produce a human-readable summary of every involved patch together with its diff. The snippet reuses details and patches from the example above:

report = details.make_report(patches, format="raw")
print(report)

See also

Patches (torch export) — background on why patches are needed and how the PatchInfo and PatchDetails data structures work.

Patches List — the full list of shipped patches with unified diffs.

Applying patches to a model and displaying the diff — a gallery example that applies patches, displays the diffs, and identifies which patches were exercised when exporting a real Transformers model.

yobx.helpers.patch_helper — API reference for PatchInfo and PatchDetails.

yobx.torch.patch — API reference for apply_patches_for_model().