# Optimization Pipeline

Embedl Deploy prepares PyTorch models for hardware deployment through a three-stage pipeline. Each stage operates on a standard `torch.fx.GraphModule`; the conversion and fusion stages produce a numerically equivalent model, while quantization introduces the controlled precision loss of INT8 inference.

## Pipeline stages

```
┌──────────────┐     ┌──────────────┐     ┌──────────────┐
│ Conversions  │ ──▶ │   Fusions    │ ──▶ │ Quantization │
│ (structural) │     │  (operator)  │     │    (INT8)    │
└──────────────┘     └──────────────┘     └──────────────┘
```

1. **Conversions** — structural graph rewrites that normalize the computation graph for the target hardware. These are necessary because hardware compilers like TensorRT have specific operator requirements and fusion patterns. For example, replacing `Flatten → Linear` with `Conv2d(1×1) → Flatten` enables TensorRT to fuse the classifier head with preceding convolutions. Similarly, decomposing `nn.MultiheadAttention` into explicit sub-modules exposes the Q/K/V projections for proper QDQ placement. Conversions run iteratively until no new matches are found.

2. **Fusions** — combine sequences of operators into single fused modules that map directly to hardware-accelerated kernels. For example, `Conv2d → BatchNorm2d → ReLU` becomes a single `FusedConvBNReLU` module. Fusions are matched and applied in a single pass, with longer patterns taking priority over shorter ones.

3. **Quantization** — insert Quantize/DeQuantize (QDQ) stubs at pattern-declared positions, calibrate scale/zero-point from representative data, and optionally run Quantization-Aware Training (QAT). The result is an ONNX-exportable model with explicit QDQ nodes that hardware compilers (like TensorRT) use for INT8 inference.
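To see why a `Conv2d → BatchNorm2d` fusion can be exact, recall the standard folding identity: in eval mode, BatchNorm is an affine transform per channel, so its statistics can be folded into the convolution's weights and bias. The sketch below is generic PyTorch illustrating that identity — it is not the library's actual fusion code, and `fold_bn_into_conv` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold eval-mode BatchNorm statistics into the preceding Conv2d."""
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding,
        dilation=conv.dilation, groups=conv.groups, bias=True,
    )
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std  # per-output-channel multiplier
    bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    with torch.no_grad():
        # W' = W * gamma / sqrt(var + eps), broadcast over (out, in, kh, kw)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        # b' = (b - running_mean) * gamma / sqrt(var + eps) + beta
        fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused
```

Because the fold only rewrites parameters, `fused(x)` matches `bn(conv(x))` up to floating-point rounding — the same equivalence the pipeline's fusion stage relies on.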
## One-shot API

The simplest way to run the pipeline is `transform()`:

```python
import torch
from torchvision.models import resnet50

from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

model = resnet50(weights="DEFAULT").eval()
result = transform(model, patterns=TENSORRT_PATTERNS)

deployable = result.model  # fused GraphModule
print(f"Applied: {result.report['applied_count']}")
print(f"Skipped: {result.report['skipped_count']}")
```

`transform()` accepts both `nn.Module` and `fx.GraphModule`. When given an `nn.Module`, it traces internally with `torch.fx.symbolic_trace`. For models with dynamic control flow or custom tracing requirements, trace the model yourself and pass the `GraphModule` directly:

```python
# Option 1: let transform() handle tracing
result = transform(model, patterns=TENSORRT_PATTERNS)

# Option 2: trace yourself for full control
graph_module = torch.fx.symbolic_trace(model)
result = transform(graph_module, patterns=TENSORRT_PATTERNS)
```

When `transform()` receives an `nn.Module`, it internally:

1. Traces the model with `torch.fx.symbolic_trace`.
2. Runs all conversion patterns (`is_conversion=True`) iteratively.
3. Runs all fusion patterns in a single pass.

## Plan-based API

For full control, use the two-step workflow:

```python
from embedl_deploy import get_transformation_plan, apply_transformation_plan
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

graph_module = torch.fx.symbolic_trace(model)
plan = get_transformation_plan(graph_module, patterns=TENSORRT_PATTERNS)

# Inspect matches
for node_name, pats in plan.matches.items():
    for pat_name, match in pats.items():
        print(f"{node_name}: {pat_name} (apply={match.apply})")

# Disable specific matches
plan.matches["maxpool"]["StemConvBNReLUMaxPoolPattern"].apply = False

# Apply
result = apply_transformation_plan(plan)
```

The plan is a `TransformationPlan` dataclass with:

- `model` — deep copy of the traced graph (never modifies the original).
- `matches` — nested dict: `node_name → pattern_name → PatternMatch`. Each `PatternMatch` has an `apply` flag. Set it to `False` to skip that match.

## Pattern priority

Patterns are matched in list order. When multiple patterns match overlapping nodes, the first match claims the nodes and later matches are marked `apply=False`. This is why pattern lists are ordered longest-first:

```python
TENSORRT_FUSION_PATTERNS = [
    StemConvBNReLUMaxPoolPattern(),      # 4 nodes
    ConvBNAddReLUPattern(),              # 4 nodes (branching)
    ConvBNReLUPattern(),                 # 3 nodes
    LinearReLUPattern(),                 # 2 nodes
    ConvBNPattern(),                     # 2 nodes
    LinearPattern(),                     # 1 node
    LayerNormPattern(),                  # 1 node
    AdaptiveAvgPoolPattern(),            # 1 node
    MHAInProjectionPattern(),            # 1 node
    ScaledDotProductAttentionPattern(),  # 1 node
]
```

For ResNet50, `StemConvBNReLUMaxPoolPattern` claims the stem's `Conv → BN → ReLU → MaxPool` nodes. The shorter `ConvBNReLUPattern` also matches the first three of those nodes, but they are already consumed — so it is automatically skipped for that subgraph.

## Pattern groups

The TensorRT backend exposes pre-built pattern lists:

| List | Contents |
|------|----------|
| `TENSORRT_CONVERSION_PATTERNS` | Structural conversions (run first) |
| `TENSORRT_FUSION_PATTERNS` | Operator fusions |
| `TENSORRT_QUANTIZED_PATTERNS` | Quantization-focused rewrites |
| `TENSORRT_PATTERNS` | Union of all of the above |

```python
from embedl_deploy.tensorrt import (
    TENSORRT_CONVERSION_PATTERNS,
    TENSORRT_FUSION_PATTERNS,
    TENSORRT_PATTERNS,
    TENSORRT_QUANTIZED_PATTERNS,
)

# Run only fusions (skip conversions)
result = transform(model, patterns=TENSORRT_FUSION_PATTERNS)

# Run only conversions
result = transform(model, patterns=TENSORRT_CONVERSION_PATTERNS)
```

## Verifying numerical equivalence

Conversions and fusions must not change model outputs.
Always verify:

```python
example_input = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    y_original = model(example_input)
    y_deployed = result.model(example_input)

max_diff = (y_original - y_deployed).abs().max().item()
assert max_diff < 1e-5, f"Numerical mismatch: {max_diff}"
```

## ONNX export and compilation

After transformation, export and compile as usual:

```python
torch.onnx.export(
    result.model.cpu().eval(),
    torch.randn(1, 3, 224, 224),
    "model_fused.onnx",
    opset_version=20,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
```

Then compile with TensorRT:

```bash
trtexec \
    --onnx=model_fused.onnx \
    --fp16 \
    --exportLayerInfo=layer_info.json \
    --profilingVerbosity=detailed
```