# Copyright (C) 2026 Embedl AB
"""Conversion ``Pattern`` subclasses for TensorRT.
Conversions are *structural* rewrites that change the graph topology so that
downstream fusion patterns can match. They run **before** fusion patterns.
Example: ``Flatten → Linear → Conv2d(1×1) → Flatten``.
"""
import logging
import operator
from typing import TypeAlias
import torch
from torch import fx, nn
from embedl_deploy._internal.core.match import match_tree
from embedl_deploy._internal.core.pattern import (
Pattern,
PatternMatch,
Tree,
Wildcard,
get_module,
resolve_module,
)
from embedl_deploy._internal.core.replace import (
ReplacementFn,
get_auto_name,
get_replaced_nodes,
replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.attention import (
MHAInProjection,
ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
SwinAttention,
SwinSpatialState,
SwinWindowPartition,
SwinWindowReverse,
)
try:
from torchvision.models.swin_transformer import (
shifted_window_attention,
)
except ImportError: # pragma: no cover
shifted_window_attention = None # type: ignore[assignment]
_LOG = logging.getLogger(__name__)
def _is_assert_noise(node: fx.Node) -> bool:
"""Return ``True`` for helper nodes typically used in assertions.
These nodes are graph noise in exported inference graphs and usually feed
directly into ``eq`` / ``_assert`` checks.
"""
if node.op == "call_function" and node.target is getattr:
return True
if node.op == "call_method" and node.target in {"dim", "size"}:
return True
return False
def _is_eq(node: fx.Node) -> bool:
"""Return ``True`` when `node` is an equality comparison."""
return node.op == "call_function" and node.target is operator.eq
def _is_assert(node: fx.Node) -> bool:
"""Return ``True`` when `node` is a torch assertion op."""
target = getattr(torch, "_assert", None)
return node.op == "call_function" and node.target is target
#: Matches one or more consecutive assertion-noise nodes (``getattr`` /
#: ``dim`` / ``size`` calls) at the head of an ``eq → _assert`` chain.
_OPTIONAL_ASSERT_NOISE = Wildcard(_is_assert_noise, quantifier="+")
def _is_dead_assert_noise(node: fx.Node) -> bool:
    """Return ``True`` for assertion-noise nodes with no users."""
    if node.users:
        return False
    return _is_assert_noise(node)
class RemoveAssertPattern(Pattern):
    """Remove assertion subgraphs such as ``getattr → eq → _assert``.

    timm models often contain shape checks in the traced FX graph (for
    example ``assert x.dim() == 4``). They are not runtime compute ops and
    can be safely erased for deployment graphs.
    """

    is_conversion = True
    tree: Tree = (_OPTIONAL_ASSERT_NOISE, _is_eq, _is_assert)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find all assertion chains in `graph_module`."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Erase the matched chain (empty replacement list)."""
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])
class RemoveDeadAssertPattern(Pattern):
    """Remove dead assertion nodes left after assert removal.

    ``RemoveAssertPattern`` can erase ``eq``/``_assert`` while leaving the
    other ``getattr`` input dead. This cleanup pass removes those dead
    nodes.
    """

    is_conversion = True
    tree: Tree = (_is_dead_assert_noise,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find user-less assertion-noise nodes in `graph_module`."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Erase the matched dead node (empty replacement list)."""
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])
def _has_input_shape(node: fx.Node) -> bool:
"""Return ``True`` when the first input to `node` has shape metadata."""
try:
if not isinstance(arg := node.args[0], fx.Node):
return False
if not hasattr(arg.meta["tensor_meta"], "shape"):
return False
except (AttributeError, IndexError, KeyError):
return False
return True
def _is_flatten(node: fx.Node) -> bool:
    """Return ``True`` when `node` is a flatten call with shape metadata."""
    # Flatten can appear as torch.flatten, Tensor.flatten or nn.Flatten.
    if node.op == "call_function":
        flattenish = node.target is torch.flatten
    elif node.op == "call_method":
        flattenish = node.target == "flatten"
    else:
        flattenish = isinstance(get_module(node), nn.Flatten)
    if not flattenish:
        return False
    # Shape metadata is required downstream to build the reshape node.
    if not _has_input_shape(node):
        msg = f"Skipping {node.target}: missing shape metadata on input node"
        _LOG.warning("%s", msg)
        return False
    return True
#: Element-wise modules (activations and dropout) that may appear between
#: a ``flatten`` and a ``Linear`` without blocking the
#: Flatten→Linear→Conv1x1 conversion.
ElementWiseLike: TypeAlias = (
    nn.Dropout
    | nn.Dropout2d
    | nn.ReLU
    | nn.ReLU6
    | nn.LeakyReLU
    | nn.ELU
    | nn.GELU
    | nn.SiLU
    | nn.Hardswish
    | nn.Hardsigmoid
)
#: Optional element-wise ops that can appear between a ``flatten`` and
#: a ``Linear`` (activations, dropout).
_OPTIONAL_EW = Wildcard(ElementWiseLike)
def _get_reshape_insert(
    flatten_node: fx.Node,
    linear: nn.Linear,
) -> ReplacementFn:
    """Return a replacement function for reshaping conv inputs.

    The returned function inserts ``reshape(-1, in_features, 1, 1)`` in
    front of the (possibly already replaced) flatten node, unless the
    incoming tensor already has that trailing shape.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        source = flatten_node.args[0]
        assert isinstance(source, fx.Node)
        current_shape = source.meta["tensor_meta"].shape
        target_shape = torch.Size([-1, linear.in_features, 1, 1])
        # Input is already (N, C, 1, 1): no reshape node needed.
        if current_shape[1:] == target_shape[1:]:
            return []
        # The flatten node may have been replaced earlier in this pass;
        # anchor the insertion at its current stand-in.
        anchor = get_replaced_nodes(graph_module).get(
            flatten_node, flatten_node
        )
        with graph_module.graph.inserting_before(anchor):
            reshape_node = graph_module.graph.call_method(
                "reshape", (prev_args[0], *target_shape)
            )
        return [reshape_node]

    return _insert
def _linear_to_conv1x1(linear: nn.Linear) -> nn.Conv2d:
"""Convert a ``Linear`` layer to an equivalent ``Conv2d(1×1)``."""
conv = nn.Conv2d(
linear.in_features,
linear.out_features,
kernel_size=1,
bias=linear.bias is not None,
)
conv.weight.data = linear.weight.data.clone().reshape(
linear.out_features,
linear.in_features,
1,
1,
)
if linear.bias is not None:
assert conv.bias is not None
conv.bias.data = linear.bias.data.clone()
return conv
class FlattenLinearToConv1x1Pattern(Pattern):
    """Replace ``Flatten (4D→2D) → Linear`` with ``Conv2d(1×1) → Flatten``.

    Many classification networks end with::

        AdaptiveAvgPool2d → flatten → [Dropout → ReLU →] Linear

    This conversion rewrites the tail into::

        AdaptiveAvgPool2d → [Dropout → ReLU →] Conv2d(1×1) → flatten

    Element-wise ops between flatten and Linear (activations, dropout) are
    absorbed by a :class:`~embedl_deploy._internal.core.pattern.Wildcard`
    and moved in front of the ``Conv2d`` in the replacement.

    The resulting ``Conv2d`` can then be matched by downstream fusion and
    Q/DQ patterns.

    This is a *structural* conversion — it changes graph topology rather
    than collapsing a chain into a fused module. It must be applied
    **before** fusion patterns.
    """

    is_conversion = True
    tree: Tree = (_is_flatten, _OPTIONAL_EW, nn.Linear)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find flatten→[element-wise]→Linear chains in `graph_module`."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Rewrite the chain as element-wise → reshape → conv → flatten."""
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        flatten_node = tree_match.get_node(0)
        wild_nodes = tree_match.get_node(1, is_wildcard=True).nodes
        linear = resolve_module(tree_match.get_node(2), nn.Linear)
        reshape_insert = _get_reshape_insert(flatten_node, linear)
        conv = _linear_to_conv1x1(linear)
        # Element-wise ops move in front of the conv; the original
        # flatten node is re-used behind it.
        return replace_tree(
            pattern_match, [*wild_nodes, reshape_insert, conv, flatten_node]
        )
def _is_identity_adaptive_avg_pool(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is an identity ``AdaptiveAvgPool2d``.

    The pool is identity when its ``output_size`` matches the spatial
    dimensions of the input tensor. Requires shape metadata
    (``tensor_meta``) on the input node.

    ``None`` in ``output_size`` means "keep the input dimension", which is
    always identity for that axis.
    """
    mod = get_module(node)
    if not isinstance(mod, nn.AdaptiveAvgPool2d):
        return False
    if not _has_input_shape(node):
        msg = f"Skipping {node.target}: missing shape metadata on input node"
        _LOG.warning("%s", msg)
        return False
    assert isinstance(node.args[0], fx.Node)
    meta = node.args[0].meta["tensor_meta"]
    # Only the standard NCHW 4D case is handled.
    if len(meta.shape) != 4:
        return False
    in_h: int = meta.shape[2]
    in_w: int = meta.shape[3]
    # ``output_size`` may be an int or an (h, w) sequence; torch accepts
    # both tuples and lists at construction time.
    if isinstance(mod.output_size, (tuple, list)):
        out_h, out_w = mod.output_size
    else:
        out_h = out_w = mod.output_size
    # Compare against ``None`` explicitly: ``out_h or in_h`` would also
    # treat a (degenerate) 0 as "keep input dimension".
    out_h = in_h if out_h is None else out_h
    out_w = in_w if out_w is None else out_w
    return (in_h == out_h) and (in_w == out_w)
class RemoveIdentityAdaptiveAvgPoolPattern(Pattern):
    """Remove ``AdaptiveAvgPool2d`` where output size equals input size.

    When ``output_size == (H, W)`` of the incoming feature map the pooling
    operation is a mathematical identity and can be safely erased. This is
    common in ConvNeXt-style architectures.

    Assumes shapes have already been propagated into
    ``node.meta['tensor_meta']`` (e.g. via a prior ``ShapeProp`` pass).
    Nodes with missing shape metadata are skipped with a warning.
    """

    is_conversion = True
    tree: Tree = (_is_identity_adaptive_avg_pool,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find identity adaptive-average-pool nodes in `graph_module`."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Erase the matched pool node (empty replacement list)."""
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])
def _is_identity_passthrough(node: fx.Node) -> bool:
    """Return ``True`` when `node` resolves to an ``nn.Identity`` module."""
    return isinstance(get_module(node), nn.Identity)
class RemoveIdentityPattern(Pattern):
    """Remove ``nn.Identity`` modules from the graph.

    ``nn.Identity`` is a no-op module that passes its input through
    unchanged. These operations can be safely removed from the graph
    without affecting model behavior. This simplifies the graph for
    downstream optimization and fusion patterns.
    """

    is_conversion = True
    tree: Tree = (_is_identity_passthrough,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find ``nn.Identity`` call-module nodes in `graph_module`."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Erase the matched identity node (empty replacement list)."""
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])
def _is_supported_mha(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is a supported ``MultiheadAttention``.

    Only self-attention (``_qkv_same_embed_dim=True``, ``batch_first=True``)
    without masks or ``add_zero_attn`` is supported. Unsupported
    configurations are skipped with a warning.
    """
    mod = get_module(node)
    if not isinstance(mod, nn.MultiheadAttention):
        return False
    reasons: list[str] = []
    # pylint: disable-next=protected-access
    if not mod._qkv_same_embed_dim:
        reasons.append(
            f"Skipping MHA at {node.target}: cross-attention "
            f"(_qkv_same_embed_dim=False) not supported"
        )
    if not mod.batch_first:
        reasons.append(
            f"Skipping MHA at {node.target}: batch_first=False not supported"
        )
    if mod.add_zero_attn:
        reasons.append(
            f"Skipping MHA at {node.target}: add_zero_attn=True not supported"
        )
    # More than (query, key, value) positionally usually means masks.
    if len(node.args) > 3:
        reasons.append(
            f"Skipping MHA at {node.target}: extra positional args "
            f"(possible mask arguments)"
        )
    for mask_key in ("attn_mask", "key_padding_mask"):
        if node.kwargs.get(mask_key) is not None:
            reasons.append(
                f"Skipping MHA at {node.target}: {mask_key} is not None"
            )
    if not reasons:
        return True
    for msg in reasons:
        _LOG.warning("%s", msg)
    return False
def _decompose_mha(
    mha_node: fx.Node,
) -> tuple[MHAInProjection, ScaledDotProductAttention, nn.Linear]:
    """Split `mha_node` into in-proj, attention and out-proj modules."""
    mha: nn.MultiheadAttention = resolve_module(
        mha_node, nn.MultiheadAttention
    )
    embed_dim = mha.embed_dim
    num_heads = mha.num_heads
    head_dim = embed_dim // num_heads
    has_bias = mha.in_proj_bias is not None
    # The packed QKV projection is an ordinary Linear (E → 3E) that
    # re-uses the original MHA parameters directly.
    qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=has_bias)
    qkv.weight = mha.in_proj_weight
    if has_bias:
        qkv.bias = mha.in_proj_bias
    in_proj = MHAInProjection(qkv, num_heads, head_dim)
    attention = ScaledDotProductAttention(
        num_heads,
        head_dim,
        dropout=mha.dropout,
    )
    # The output projection is re-used as-is.
    return in_proj, attention, mha.out_proj
def _unwrap_mha_getitems(
    mha_node: fx.Node,
) -> None:
    """Replace ``getitem`` users of `mha_node` with direct references.

    ``MultiheadAttention`` returns ``(output, weights)``. ``getitem(mha, 0)``
    users are rewired to reference `mha_node` directly and ``getitem(mha, 1)``
    users are replaced with a ``None`` constant. After this call `mha_node`
    behaves as a single-output node, allowing
    :func:`~embedl_deploy._internal.core.replace.replace_tree` to handle it
    normally.
    """
    graph_module = mha_node.graph.owning_module
    assert isinstance(graph_module, fx.GraphModule)
    graph = graph_module.graph
    # The node may already have been replaced earlier in this pass.
    resolved = get_replaced_nodes(graph_module).get(mha_node, mha_node)
    # Snapshot the users: the loop mutates ``resolved.users``.
    for user in list(resolved.users):
        if user.op != "call_function" or user.target is not operator.getitem:
            continue
        index = user.args[1]
        if index == 0:
            # Attention output: point users at the MHA node itself.
            user.replace_all_uses_with(resolved)
            graph.erase_node(user)
        elif index == 1:
            # Attention weights: materialize a constant ``None`` via
            # ``getitem((None,), 0)``.
            with graph.inserting_after(resolved):
                none_node = graph.call_function(
                    operator.getitem, ((None,), 0)
                )
            user.replace_all_uses_with(none_node)
            graph.erase_node(user)
def _get_attention_insert(
    attention: ScaledDotProductAttention,
) -> ReplacementFn:
    """Return a replacement function for Q/K/V splitting.

    Splits the in-projection output into Q/K/V via ``getitem`` and feeds
    them into `attention`.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        module_name = get_auto_name(graph_module, attention)
        graph_module.add_module(module_name, attention)
        (proj_node,) = prev_args
        graph = graph_module.graph
        # Emit getitem(proj, i) for i in 0..2, each inserted after the
        # previous node so the graph stays topologically ordered.
        qkv: list[fx.Node] = []
        anchor = proj_node
        for index in range(3):
            with graph.inserting_after(anchor):
                anchor = graph.call_function(
                    operator.getitem, (proj_node, index)
                )
            qkv.append(anchor)
        with graph.inserting_after(anchor):
            attn_node = graph.call_module(module_name, tuple(qkv))
        return [*qkv, attn_node]

    return _insert
class DecomposeMultiheadAttentionPattern(Pattern):
    """Decompose ``nn.MultiheadAttention`` into explicit sub-modules.

    Replaces each ``MultiheadAttention`` node with three sub-modules
    visible in the FX graph:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    3. ``nn.Linear`` (output projection, reused from the original MHA)

    Only self-attention (``_qkv_same_embed_dim=True``, ``batch_first=True``)
    without masks is supported. Unsupported configurations are skipped with
    a warning.
    """

    is_conversion = True
    tree: Tree = (_is_supported_mha,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find supported ``MultiheadAttention`` nodes in `graph_module`."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace the MHA node with in-proj → attention → out-proj."""
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        mha_node = tree_match.get_node(0)
        in_proj, attention, out_proj = _decompose_mha(mha_node)
        attention_insert = _get_attention_insert(attention)
        # Normalize the two-output MHA node before tree replacement.
        _unwrap_mha_getitems(mha_node)
        return replace_tree(
            pattern_match, [in_proj, attention_insert, out_proj]
        )
# -- Swin Transformer attention decomposition -------------------------
def _is_shifted_window_attention(node: fx.Node) -> bool:
    """Return ``True`` for a supported ``shifted_window_attention`` call.

    Only Swin V1 is supported. V2 nodes (which carry a ``logit_scale``
    kwarg) are rejected. Returns ``False`` when *torchvision* is not
    installed.
    """
    if shifted_window_attention is None:
        # torchvision is missing: the pattern can never match.
        return False
    if node.op != "call_function":
        return False
    if node.target is not shifted_window_attention:
        return False
    if node.kwargs.get("logit_scale") is None:
        return True
    _LOG.warning(
        "Skipping %s: Swin V2 (logit_scale) not supported",
        node.name,
    )
    return False
def _resolve_get_attr(
graph_module: fx.GraphModule,
node: fx.Node,
) -> torch.Tensor | nn.Parameter:
"""Follow a ``get_attr`` node to the actual tensor."""
obj: object = graph_module
for part in node.target.split("."):
obj = getattr(obj, part)
assert isinstance(obj, torch.Tensor)
return obj
def _decompose_swin_attention(  # pylint: disable=too-many-locals
    swa_node: fx.Node,
    graph_module: fx.GraphModule,
) -> tuple[
    SwinWindowPartition,
    MHAInProjection,
    SwinAttention,
    nn.Linear,
    SwinWindowReverse,
]:
    """Split a ``shifted_window_attention`` call into five sub-modules.

    Extracts weights and parameters from the ``get_attr`` nodes in the
    function call's arguments, since the ``fx.GraphModule`` stores
    traced-through modules as generic ``nn.Module`` wrappers.
    """
    # Positional args (per torchvision's shifted_window_attention
    # signature — confirm against the installed version):
    # (input, qkv_weight, proj_weight, relative_position_bias,
    #  window_size, num_heads).
    qkv_weight = _resolve_get_attr(graph_module, swa_node.args[1])
    proj_weight = _resolve_get_attr(graph_module, swa_node.args[2])
    window_size: list[int] = list(swa_node.args[4])
    num_heads: int = swa_node.args[5]
    # Keyword args.
    shift_size: list[int] = list(swa_node.kwargs["shift_size"])
    attention_dropout: float = swa_node.kwargs["attention_dropout"]
    qkv_bias_node = swa_node.kwargs.get("qkv_bias")
    proj_bias_node = swa_node.kwargs.get("proj_bias")
    # The packed QKV weight is (3·E, E), so dim 1 carries the embed dim.
    embed_dim = qkv_weight.shape[1]
    head_dim = embed_dim // num_heads
    # Relative position bias inputs come from the preceding
    # _get_relative_position_bias call feeding positional arg 3.
    rel_bias_node = swa_node.args[3]
    bias_table = _resolve_get_attr(graph_module, rel_bias_node.args[0])
    bias_index = _resolve_get_attr(graph_module, rel_bias_node.args[1])
    # Rebuild the QKV projection as a plain nn.Linear.
    has_qkv_bias = isinstance(qkv_bias_node, fx.Node)
    qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=has_qkv_bias)
    qkv_linear.weight = nn.Parameter(qkv_weight)
    if has_qkv_bias:
        qkv_linear.bias = nn.Parameter(
            _resolve_get_attr(graph_module, qkv_bias_node),
        )
    # Rebuild the output projection likewise.
    has_proj_bias = isinstance(proj_bias_node, fx.Node)
    out_proj = nn.Linear(embed_dim, embed_dim, bias=has_proj_bias)
    out_proj.weight = nn.Parameter(proj_weight)
    if has_proj_bias:
        out_proj.bias = nn.Parameter(
            _resolve_get_attr(graph_module, proj_bias_node),
        )
    # Partition, attention and reverse share one SwinSpatialState
    # instance (presumably to communicate spatial dimensions between the
    # stages — see the swin_attention module).
    spatial_state = SwinSpatialState()
    win_part = SwinWindowPartition(
        window_size,
        shift_size,
        spatial_state,
    )
    in_proj = MHAInProjection(qkv_linear, num_heads, head_dim)
    attention = SwinAttention(
        relative_position_bias_table=nn.Parameter(bias_table),
        relative_position_index=bias_index,
        window_size=window_size,
        num_heads=num_heads,
        head_dim=head_dim,
        attention_dropout=attention_dropout,
        spatial_state=spatial_state,
    )
    win_rev = SwinWindowReverse(window_size, spatial_state)
    return win_part, in_proj, attention, out_proj, win_rev
def _get_swin_replace(
    win_part: SwinWindowPartition,
    in_proj: MHAInProjection,
    attention: SwinAttention,
    out_proj: nn.Linear,
    win_rev: SwinWindowReverse,
) -> ReplacementFn:
    """Return a replacement function for the full Swin decomposition."""

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        # prev_args contains every fx.Node positional arg of the original
        # shifted_window_attention call. Only the first (the input) is
        # consumed; the rest (weight get_attrs, rel-bias node) are
        # superseded by the new modules.
        input_node = prev_args[0]
        graph = graph_module.graph

        def _register(module: nn.Module) -> str:
            # Attach `module` to the graph module under a fresh name.
            name = get_auto_name(graph_module, module)
            graph_module.add_module(name, module)
            return name

        def _chain(anchor: fx.Node, module: nn.Module) -> fx.Node:
            # Insert a single-input call_module right after `anchor`.
            with graph.inserting_after(anchor):
                return graph.call_module(_register(module), (anchor,))

        # 1. SwinWindowPartition
        wp_node = _chain(input_node, win_part)
        # 2. MHAInProjection (self-attention: q = k = v)
        ip_name = _register(in_proj)
        with graph.inserting_after(wp_node):
            ip_node = graph.call_module(
                ip_name,
                (wp_node, wp_node, wp_node),
            )
        # 3. getitem nodes for Q, K, V
        qkv: list[fx.Node] = []
        anchor = ip_node
        for index in range(3):
            with graph.inserting_after(anchor):
                anchor = graph.call_function(
                    operator.getitem,
                    (ip_node, index),
                )
            qkv.append(anchor)
        # 4. SwinAttention
        at_name = _register(attention)
        with graph.inserting_after(anchor):
            at_node = graph.call_module(at_name, tuple(qkv))
        # 5. nn.Linear (output projection)
        op_node = _chain(at_node, out_proj)
        # 6. SwinWindowReverse
        wr_node = _chain(op_node, win_rev)
        return [wp_node, ip_node, *qkv, at_node, op_node, wr_node]

    return _insert
def _cleanup_swin_dead_nodes(
graph_module: fx.GraphModule,
swa_node: fx.Node,
) -> None:
"""Erase orphaned ``get_attr`` and ``_get_relative_position_bias`` nodes.
After :func:`replace_tree` erases the ``shifted_window_attention``
node, its input ``get_attr`` nodes and the preceding
``_get_relative_position_bias`` call may have no remaining users.
"""
graph = graph_module.graph
# Collect candidate dead nodes: positional args + kwargs values.
candidates: list[fx.Node] = [
a for a in swa_node.args if isinstance(a, fx.Node)
]
candidates.extend(
v for v in swa_node.kwargs.values() if isinstance(v, fx.Node)
)
# The _get_relative_position_bias node's inputs are also candidates.
rel_bias_node = swa_node.args[3]
if isinstance(rel_bias_node, fx.Node):
candidates.extend(
a for a in rel_bias_node.args if isinstance(a, fx.Node)
)
# Erase nodes with no remaining users (rel-bias first, then attrs).
for node in candidates:
if len(node.users) == 0:
graph.erase_node(node)
class DecomposeSwinAttentionPattern(Pattern):
    """Decompose ``shifted_window_attention`` into explicit sub-modules.

    Matches the ``call_function`` node produced when
    ``torchvision.models.swin_transformer.ShiftedWindowAttention`` is
    FX-traced (``shifted_window_attention`` is ``torch.fx.wrap``-ped).

    Replaces each node with five sub-modules:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowPartition`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection` (reused)
    3. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinAttention`
    4. ``nn.Linear`` (output projection, reused from the original)
    5. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowReverse`

    Only Swin V1 (``ShiftedWindowAttention``) is supported. V2 nodes
    that carry a ``logit_scale`` keyword argument are skipped.
    """

    is_conversion = True
    tree: Tree = (_is_shifted_window_attention,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find supported ``shifted_window_attention`` calls."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace the call with the five-module decomposition."""
        assert pattern_match.pattern is self
        swa_node = pattern_match.tree_match.get_node(0)
        graph_module = pattern_match.graph_module
        submodules = _decompose_swin_attention(swa_node, graph_module)
        new_nodes = replace_tree(
            pattern_match, [_get_swin_replace(*submodules)]
        )
        # The erased call leaves its weight get_attr inputs dead.
        _cleanup_swin_dead_nodes(graph_module, swa_node)
        return new_nodes