# Source code for embedl_hub._internal.core.compile.tflite

# Copyright (C) 2026 Embedl AB

"""TFLite compiler component.

Supports compilation via:
- **``local``**: Local ONNX → TFLite conversion using ``onnx2tf`` and
  TensorFlow Lite.  No remote device required.
- **``qai_hub``** devices: Qualcomm AI Hub cloud compilation targeting
  the TensorFlow Lite runtime.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import qai_hub as hub
from qai_hub.client import CompileJob, QuantizeJob

from embedl_hub._internal.core.compile.result import CompileError
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import CompiledModel
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.utils.calibration_data import (
    generate_random_calibration_data,
    load_onnx_calibration_data,
)
from embedl_hub._internal.core.utils.onnx_utils import load_onnx_model
from embedl_hub._internal.core.utils.qai_hub_utils import (
    log_runtime_info,
    resolve_qai_hub_device,
    save_qai_hub_tflite_model,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact

# Calibration data accepted by the TFLite compiler: either a path to a
# dataset on disk, or an in-memory mapping from input tensor name to a
# numpy array of samples.
CalibrationData = Path | dict[str, np.ndarray]


# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class TFLiteCompiledModel(CompiledModel):
    """Result of a TFLite compilation step.

    Extends
    :class:`~embedl_hub._internal.core.component.output.CompiledModel`
    with optional tensor-name mappings recording how the compiler
    renamed the model's input and output tensors.
    """

    # Original input tensor name -> compiled name; ``None`` when no
    # input tensor was renamed.
    input_name_mapping: dict[str, str] | None = None
    # Original output tensor name -> compiled name; ``None`` when no
    # output tensor was renamed.
    output_name_mapping: dict[str, str] | None = None
# --------------------------------------------------------------------------- # Component class # ---------------------------------------------------------------------------
class TFLiteCompiler(Component):
    """Compile an ONNX model into TFLite format.

    Two providers are supported: a fully local conversion (via
    ``onnx2tf``) and a cloud compilation through Qualcomm AI Hub.
    """

    run_type = RunType.COMPILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        quantize_io: bool = False,
    ) -> None:
        """Create a new TFLite compiler.

        Arguments given here act as defaults for every :meth:`run`
        call and may be overridden per call.

        :param name: Display name for this compiler instance.
            Defaults to the class name.
        :param device: Name of the target device.
        :param input_shape: Input shape tuple. *Only used on*
            ``qai_hub`` *devices.*
        :param calibration_data: Calibration data for quantization.
            *Only used on* ``qai_hub`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors.
            *Only used on* ``qai_hub`` *devices.*
        """
        defaults = {
            "name": name,
            "device": device,
            "input_shape": input_shape,
            "calibration_data": calibration_data,
            "quantize_io": quantize_io,
        }
        super().__init__(**defaults)

    def run(
        self,
        ctx: HubContext,
        onnx_path: Path,
        *,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        quantize_io: bool = False,
    ) -> TFLiteCompiledModel:
        """Compile an ONNX model to TFLite.

        Keyword arguments override the constructor defaults; any
        argument omitted here falls back to the constructor value.

        :param ctx: The execution context with device configuration.
        :param onnx_path: Path to the input ONNX model.
        :param device: Name of the target device.
        :param input_shape: Input shape tuple. *Only used on*
            ``qai_hub`` *devices.*
        :param calibration_data: Calibration data for quantization.
            *Only used on* ``qai_hub`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors.
            *Only used on* ``qai_hub`` *devices.*
        :return: A :class:`TFLiteCompiledModel` with the path to the
            compiled model.
        """
        # The component framework replaces this stub with a concrete
        # provider implementation at dispatch time.
        raise NoProviderError
# --------------------------------------------------------------------- #
# Local provider
# --------------------------------------------------------------------- #


@TFLiteCompiler.provider(ProviderType.LOCAL)
def _compile_tflite_local(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    quantize_io: bool = False,
) -> TFLiteCompiledModel:
    """Compile an ONNX model to TFLite locally using ``onnx2tf``.

    This provider runs entirely on the local machine without any
    remote device.  FP16 quantization is applied by default.

    :param ctx: The execution context.
    :param onnx_path: Path to the input ONNX model.
    :param device: Ignored (no remote device).
    :param input_shape: Ignored.
    :param calibration_data: Ignored (local conversion uses FP16 only).
    :param quantize_io: Ignored.
    :return: A :class:`TFLiteCompiledModel` with the path to the
        compiled model.
    """
    # Imported lazily so the TensorFlow stack is only loaded when the
    # local provider is actually used.
    from embedl_hub._internal.core.compile.onnx_to_tf import (
        compile_onnx_to_tflite,
    )

    output_path = ctx.artifact_dir / f"compiled_{onnx_path.stem}.tflite"
    result = compile_onnx_to_tflite(
        onnx_model_path=onnx_path,
        output_model_path=output_path,
        fp16=True,
    )
    if ctx.client is not None:
        # Tracked run: log input/output artifacts and let the run log
        # assemble the output object.
        ctx.client.log_artifact(
            onnx_path, name="input", file_name=f"input_{onnx_path.name}"
        )
        ctx.client.log_artifact(result.model_path, name="path")
        return TFLiteCompiledModel.from_current_run(ctx)
    # No tracking client — return a minimal output.
    from datetime import UTC, datetime

    now = datetime.now(UTC)
    # Hand-build the output with placeholder artifact IDs; file sizes
    # fall back to 0 if the files are unexpectedly missing.
    return TFLiteCompiledModel(
        artifact_dir=ctx.artifact_dir,
        devices={},
        run_log=None,
        input=LoggedArtifact(
            id="",
            file_name=onnx_path.name,
            file_size=onnx_path.stat().st_size if onnx_path.exists() else 0,
            logged_at=now,
            file_path=onnx_path,
            name="input",
        ),
        path=LoggedArtifact(
            id="",
            file_name=result.model_path.name,
            file_size=(
                result.model_path.stat().st_size
                if result.model_path.exists()
                else 0
            ),
            logged_at=now,
            file_path=result.model_path,
            name="path",
        ),
    )


# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #


@TFLiteCompiler.provider(ProviderType.QAI_HUB)
def _compile_tflite_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    quantize_io: bool = False,
) -> TFLiteCompiledModel:
    """Compile an ONNX model to TFLite using Qualcomm AI Hub.

    The model is first quantized to INT8 via
    ``hub.submit_quantize_job()``, then compiled for TensorFlow Lite
    via ``hub.submit_compile_job()``.

    :param ctx: The execution context with device configuration.
    :param onnx_path: Path to the input ONNX model.
    :param device: Name of the target device.
    :param input_shape: Input shape tuple.
    :param calibration_data: Calibration data for quantization.
    :param quantize_io: If ``True``, quantize I/O tensors.
    :return: A :class:`TFLiteCompiledModel` with the path to the
        compiled model.
    """
    # Log the input model artifact
    ctx.client.log_artifact(
        onnx_path, name="input", file_name=f"input_{onnx_path.name}"
    )

    # NOTE(review): ``device_name`` is never used below — confirm
    # whether it should be logged as a run parameter.
    hub_device, device_name = resolve_qai_hub_device(ctx, device)

    if input_shape:
        ctx.client.log_param("input_shape", "x".join(map(str, input_shape)))

    # Step 1: Quantize
    quantized_model_path = _quantize_via_qai_hub(
        ctx, onnx_path, calibration_data
    )

    # Step 2: Compile to TFLite
    compiled_model_path = _compile_via_qai_hub(
        ctx,
        quantized_model_path,
        hub_device,
        input_shape=input_shape,
        quantize_io=quantize_io,
        output_name=f"compiled_{onnx_path.stem}.tflite",
    )

    # Log the compiled model artifact
    ctx.client.log_artifact(compiled_model_path, name="path")

    # Build input/output name mappings by comparing the
    # pre-compile (quantized ONNX) and post-compile (TFLite) models
    # positionally.
    input_name_mapping, output_name_mapping = _build_name_mappings(
        quantized_model_path, compiled_model_path
    )
    if input_name_mapping:
        ctx.client.log_param(
            "input_name_mapping",
            ", ".join(f"{k} -> {v}" for k, v in input_name_mapping.items()),
        )
    if output_name_mapping:
        ctx.client.log_param(
            "output_name_mapping",
            ", ".join(f"{k} -> {v}" for k, v in output_name_mapping.items()),
        )

    # Empty mappings are normalized to None on the output object.
    return TFLiteCompiledModel.from_current_run(
        ctx,
        input_name_mapping=input_name_mapping or None,
        output_name_mapping=output_name_mapping or None,
    )


# -- QAI Hub helpers -------------------------------------------------------


def _quantize_via_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    calibration_data: CalibrationData | None,
) -> Path:
    """Submit a quantization job to Qualcomm AI Hub.

    :param ctx: The execution context.
    :param onnx_path: Path to the floating-point ONNX model.
    :param calibration_data: Optional calibration data.
    :return: Local path to the quantized ONNX model.
    :raises CompileError: If the quantization job fails.
    """
    # Prepare calibration dataset.  ``list(array)`` splits each array
    # along its first axis into a list of per-sample arrays.
    if isinstance(calibration_data, dict):
        calib_dataset: dict[str, list[np.ndarray]] = {
            name: list(array) for name, array in calibration_data.items()
        }
        ctx.client.log_param(
            "num_calibration_samples",
            str(min(len(v) for v in calib_dataset.values())),
        )
    elif calibration_data is not None:
        # Path input: load samples from disk keyed by model input names.
        calib_dataset = load_onnx_calibration_data(
            model_path=onnx_path,
            data_path=calibration_data,
        )
        ctx.client.log_param(
            "num_calibration_samples",
            str(min(len(v) for v in calib_dataset.values())),
        )
    else:
        # No data supplied: fall back to a single random sample.
        calib_dataset = generate_random_calibration_data(onnx_path)
        ctx.client.log_param("num_calibration_samples", "1")

    # Submit quantization job
    try:
        quantize_job: QuantizeJob = hub.submit_quantize_job(
            model=onnx_path.as_posix(),
            calibration_data=calib_dataset,
            weights_dtype=hub.QuantizeDtype.INT8,
            activations_dtype=hub.QuantizeDtype.INT8,
        )
    except Exception as error:
        raise CompileError("Failed to submit quantization job.") from error

    ctx.client.log_param("qai_hub_quantize_job_id", quantize_job.job_id)

    # Download quantized model
    try:
        quantized_model = quantize_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download quantized model from Qualcomm AI Hub."
        ) from error
    if quantized_model is None:
        raise CompileError("Quantization job did not produce a target model.")

    # Save quantized ONNX model (intermediate artifact)
    from embedl_hub._internal.core.utils.qai_hub_utils import (
        save_qai_hub_model,
    )

    quantized_path = save_qai_hub_model(
        quantized_model,
        ctx.artifact_dir / f"quantized_{onnx_path.name}",
    )
    ctx.client.log_artifact(
        quantized_path,
        name="quantized",
        file_name=f"quantized_{onnx_path.name}",
    )
    return quantized_path


def _build_name_mappings(
    source_model_path: Path,
    compiled_model_path: Path,
) -> tuple[dict[str, str], dict[str, str]]:
    """Build input/output name mappings between an ONNX and a TFLite model.

    Compares the input and output tensor names of the source ONNX
    model and the compiled TFLite model **by position**.  Only names
    that actually changed are included in the returned dictionaries.

    :param source_model_path: Path to the pre-compile ONNX model.
    :param compiled_model_path: Path to the compiled ``.tflite`` model.
    :return: A ``(input_mapping, output_mapping)`` tuple where each
        dict maps original names to compiled names.  Empty dicts if no
        names changed.
    :raises CompileError: If the number of inputs or outputs differs
        between the two models.
    """
    from embedl_hub._internal.core.utils.tflite_utils import (
        get_tflite_model_input_names,
        get_tflite_model_output_names,
    )

    source_model = load_onnx_model(source_model_path)

    src_input_names = [inp.name for inp in source_model.graph.input]
    tgt_input_names = get_tflite_model_input_names(str(compiled_model_path))
    # Positional matching is only sound when both models agree on the
    # number of inputs/outputs — bail out loudly otherwise.
    if len(src_input_names) != len(tgt_input_names):
        raise CompileError(
            f"Input count mismatch: source model has "
            f"{len(src_input_names)} input(s) but compiled model has "
            f"{len(tgt_input_names)}. Cannot build name mapping."
        )

    src_output_names = [out.name for out in source_model.graph.output]
    tgt_output_names = get_tflite_model_output_names(str(compiled_model_path))
    if len(src_output_names) != len(tgt_output_names):
        raise CompileError(
            f"Output count mismatch: source model has "
            f"{len(src_output_names)} output(s) but compiled model has "
            f"{len(tgt_output_names)}. Cannot build name mapping."
        )

    input_mapping: dict[str, str] = {}
    for src_name, tgt_name in zip(src_input_names, tgt_input_names):
        if src_name != tgt_name:
            input_mapping[src_name] = tgt_name

    output_mapping: dict[str, str] = {}
    for src_name, tgt_name in zip(src_output_names, tgt_output_names):
        if src_name != tgt_name:
            output_mapping[src_name] = tgt_name

    return input_mapping, output_mapping


def _compile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
    *,
    input_shape: tuple[int, ...] | None,
    quantize_io: bool,
    output_name: str = "compiled_model.tflite",
) -> Path:
    """Submit a compile job to Qualcomm AI Hub targeting TFLite.

    :param ctx: The execution context.
    :param model_path: Path to the (quantized) ONNX model.
    :param hub_device: Target QAI Hub device.
    :param input_shape: Optional input shape override.
    :param quantize_io: Whether to quantize I/O tensors.
    :param output_name: Filename for the compiled model artifact.
    :return: Local path to the compiled TFLite model.
    :raises CompileError: If the compile job fails.
    """
    # Build compile options — target TFLite runtime
    opts = "--target_runtime tflite"
    if quantize_io:
        opts += " --quantize_io"

    # NOTE(review): the input name "image" is hard-coded here — this
    # assumes a single-input model whose input is named "image".
    # Verify against the models this provider is expected to handle.
    input_specs = {"image": input_shape} if input_shape else None

    # Submit compile job
    try:
        compile_job: CompileJob = hub.submit_compile_job(
            model=model_path.as_posix(),
            device=hub_device,
            options=opts,
            input_specs=input_specs,
        )
    except Exception as error:
        raise CompileError("Failed to submit compile job.") from error

    # NOTE(review): the same job id is logged under two keys; the
    # "$"-prefixed key is presumably a tracking-system convention —
    # confirm before changing.
    ctx.client.log_param("$qai_hub_job_id", compile_job.job_id)
    ctx.client.log_param("qai_hub_compile_job_id", compile_job.job_id)

    # Download compiled model
    try:
        compiled_model = compile_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download compiled model from Qualcomm AI Hub."
        ) from error
    if compiled_model is None:
        raise CompileError("Compile job did not produce a target model.")

    compiled_model_path = save_qai_hub_tflite_model(
        compiled_model, ctx.artifact_dir / output_name
    )

    # Parse and log runtime info
    log_runtime_info(ctx.client, compile_job.job_id, error_class=CompileError)

    return compiled_model_path