# Copyright (C) 2026 Embedl AB
"""TFLite compiler component.
Supports compilation via:
- **``local``**: Local ONNX → TFLite conversion using ``onnx2tf`` and
TensorFlow Lite. No remote device required.
- **``qai_hub``** devices: Qualcomm AI Hub cloud compilation targeting
the TensorFlow Lite runtime.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import qai_hub as hub
from qai_hub.client import CompileJob, QuantizeJob
from embedl_hub._internal.core.compile.result import CompileError
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import CompiledModel
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.utils.calibration_data import (
generate_random_calibration_data,
load_onnx_calibration_data,
)
from embedl_hub._internal.core.utils.onnx_utils import load_onnx_model
from embedl_hub._internal.core.utils.qai_hub_utils import (
log_runtime_info,
resolve_qai_hub_device,
save_qai_hub_tflite_model,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact
CalibrationData = Path | dict[str, np.ndarray]
# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class TFLiteCompiledModel(CompiledModel):
    """Output of a TFLite compilation step.

    Extends :class:`~embedl_hub._internal.core.component.output.CompiledModel`
    with optional input/output tensor name mappings that record how
    the compiler renamed tensors.
    """

    # Maps original (pre-compile) input tensor names to the names used
    # by the compiled TFLite model; ``None`` when no inputs were renamed.
    input_name_mapping: dict[str, str] | None = None
    # Maps original output tensor names to the compiled TFLite names;
    # ``None`` when no outputs were renamed.
    output_name_mapping: dict[str, str] | None = None
# ---------------------------------------------------------------------------
# Component class
# ---------------------------------------------------------------------------
class TFLiteCompiler(Component):
    """Compile an ONNX model to TFLite format.

    Supports both local compilation (via ``onnx2tf``) and cloud
    compilation via Qualcomm AI Hub.
    """

    run_type = RunType.COMPILE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        quantize_io: bool = False,
    ) -> None:
        """Create a new TFLite compiler.

        Parameters passed here become the defaults for every
        :meth:`run` call. They can be overridden per call by passing
        them to :meth:`run` directly.

        :param name: Display name for this compiler instance. Defaults
            to the class name.
        :param device: Name of the target device.
        :param input_shape: Input shape tuple. *Only used on*
            ``qai_hub`` *devices.*
        :param calibration_data: Calibration data for quantization.
            *Only used on* ``qai_hub`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors.
            *Only used on* ``qai_hub`` *devices.*
        """
        super().__init__(
            name=name,
            device=device,
            input_shape=input_shape,
            calibration_data=calibration_data,
            quantize_io=quantize_io,
        )

    def run(
        self,
        ctx: HubContext,
        onnx_path: Path,
        *,
        device: str | None = None,
        input_shape: tuple[int, ...] | None = None,
        calibration_data: CalibrationData | None = None,
        quantize_io: bool = False,
    ) -> TFLiteCompiledModel:
        """Compile an ONNX model to TFLite.

        Keyword arguments override the defaults set in the constructor.
        If a keyword argument is not provided here, the constructor
        default is used.

        :param ctx: The execution context with device configuration.
        :param onnx_path: Path to the input ONNX model.
        :param device: Name of the target device.
        :param input_shape: Input shape tuple. *Only used on*
            ``qai_hub`` *devices.*
        :param calibration_data: Calibration data for quantization.
            *Only used on* ``qai_hub`` *devices.*
        :param quantize_io: If ``True``, quantize model I/O tensors.
            *Only used on* ``qai_hub`` *devices.*
        :return: A :class:`TFLiteCompiledModel` with the path to the
            compiled model.
        """
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# Local provider
# --------------------------------------------------------------------- #
@TFLiteCompiler.provider(ProviderType.LOCAL)
def _compile_tflite_local(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    quantize_io: bool = False,
) -> TFLiteCompiledModel:
    """Compile an ONNX model to TFLite locally using ``onnx2tf``.

    Runs entirely on the local machine without any remote device; FP16
    quantization is always applied.

    :param ctx: The execution context.
    :param onnx_path: Path to the input ONNX model.
    :param device: Ignored (no remote device).
    :param input_shape: Ignored.
    :param calibration_data: Ignored (local conversion uses FP16 only).
    :param quantize_io: Ignored.
    :return: A :class:`TFLiteCompiledModel` with the path to the
        compiled model.
    """
    from embedl_hub._internal.core.compile.onnx_to_tf import (
        compile_onnx_to_tflite,
    )

    target_path = ctx.artifact_dir / f"compiled_{onnx_path.stem}.tflite"
    conversion = compile_onnx_to_tflite(
        onnx_model_path=onnx_path,
        output_model_path=target_path,
        fp16=True,
    )

    client = ctx.client
    if client is not None:
        # Tracking client available: log artifacts and let the run
        # record build the output object.
        client.log_artifact(
            onnx_path, name="input", file_name=f"input_{onnx_path.name}"
        )
        client.log_artifact(conversion.model_path, name="path")
        return TFLiteCompiledModel.from_current_run(ctx)

    # No tracking client — assemble a minimal output by hand.
    from datetime import UTC, datetime

    timestamp = datetime.now(UTC)

    def _artifact(path: Path, label: str) -> LoggedArtifact:
        # Wrap *path* in a LoggedArtifact; size is best-effort (0 when
        # the file is missing).
        size = path.stat().st_size if path.exists() else 0
        return LoggedArtifact(
            id="",
            file_name=path.name,
            file_size=size,
            logged_at=timestamp,
            file_path=path,
            name=label,
        )

    return TFLiteCompiledModel(
        artifact_dir=ctx.artifact_dir,
        devices={},
        run_log=None,
        input=_artifact(onnx_path, "input"),
        path=_artifact(conversion.model_path, "path"),
    )
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #
@TFLiteCompiler.provider(ProviderType.QAI_HUB)
def _compile_tflite_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    *,
    device: str | None = None,
    input_shape: tuple[int, ...] | None = None,
    calibration_data: CalibrationData | None = None,
    quantize_io: bool = False,
) -> TFLiteCompiledModel:
    """Compile an ONNX model to TFLite using Qualcomm AI Hub.

    The model is first quantized to INT8 via ``hub.submit_quantize_job()``,
    then compiled for TensorFlow Lite via ``hub.submit_compile_job()``.

    :param ctx: The execution context with device configuration.
    :param onnx_path: Path to the input ONNX model.
    :param device: Name of the target device.
    :param input_shape: Input shape tuple.
    :param calibration_data: Calibration data for quantization.
    :param quantize_io: If ``True``, quantize I/O tensors.
    :return: A :class:`TFLiteCompiledModel` with the path to the
        compiled model.
    """
    # Record the source model before any remote work starts.
    ctx.client.log_artifact(
        onnx_path, name="input", file_name=f"input_{onnx_path.name}"
    )

    hub_device, device_name = resolve_qai_hub_device(ctx, device)
    if input_shape:
        ctx.client.log_param("input_shape", "x".join(map(str, input_shape)))

    # Step 1: INT8 quantization on AI Hub.
    quantized_path = _quantize_via_qai_hub(ctx, onnx_path, calibration_data)

    # Step 2: compile the quantized model for the TFLite runtime.
    tflite_path = _compile_via_qai_hub(
        ctx,
        quantized_path,
        hub_device,
        input_shape=input_shape,
        quantize_io=quantize_io,
        output_name=f"compiled_{onnx_path.stem}.tflite",
    )
    ctx.client.log_artifact(tflite_path, name="path")

    # Compare tensor names of the pre-compile (quantized ONNX) and
    # post-compile (TFLite) models positionally, then log any renames.
    in_map, out_map = _build_name_mappings(quantized_path, tflite_path)
    for param_key, mapping in (
        ("input_name_mapping", in_map),
        ("output_name_mapping", out_map),
    ):
        if mapping:
            ctx.client.log_param(
                param_key,
                ", ".join(f"{k} -> {v}" for k, v in mapping.items()),
            )

    return TFLiteCompiledModel.from_current_run(
        ctx,
        input_name_mapping=in_map or None,
        output_name_mapping=out_map or None,
    )
# -- QAI Hub helpers -------------------------------------------------------
def _quantize_via_qai_hub(
    ctx: HubContext,
    onnx_path: Path,
    calibration_data: CalibrationData | None,
) -> Path:
    """Submit a quantization job to Qualcomm AI Hub.

    :param ctx: The execution context.
    :param onnx_path: Path to the floating-point ONNX model.
    :param calibration_data: Optional calibration data. A dict maps
        input names to stacked per-input arrays; a path is loaded via
        :func:`load_onnx_calibration_data`; ``None`` falls back to
        random calibration data.
    :return: Local path to the quantized ONNX model.
    :raises CompileError: If the quantization job fails.
    """
    # Prepare calibration dataset.
    if isinstance(calibration_data, dict):
        # Split each stacked array into a list of per-sample arrays,
        # as expected by the quantize job.
        calib_dataset: dict[str, list[np.ndarray]] = {
            name: list(array) for name, array in calibration_data.items()
        }
    elif calibration_data is not None:
        calib_dataset = load_onnx_calibration_data(
            model_path=onnx_path,
            data_path=calibration_data,
        )
    else:
        calib_dataset = generate_random_calibration_data(onnx_path)

    # Log the sample count once instead of duplicating it per branch.
    if calibration_data is None:
        # Random calibration is logged as a single sample.
        num_samples = "1"
    else:
        # If inputs carry differing sample counts, report the minimum.
        num_samples = str(min(len(v) for v in calib_dataset.values()))
    ctx.client.log_param("num_calibration_samples", num_samples)

    # Submit quantization job (INT8 weights and activations).
    try:
        quantize_job: QuantizeJob = hub.submit_quantize_job(
            model=onnx_path.as_posix(),
            calibration_data=calib_dataset,
            weights_dtype=hub.QuantizeDtype.INT8,
            activations_dtype=hub.QuantizeDtype.INT8,
        )
    except Exception as error:
        raise CompileError("Failed to submit quantization job.") from error
    ctx.client.log_param("qai_hub_quantize_job_id", quantize_job.job_id)

    # Download quantized model.
    try:
        quantized_model = quantize_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download quantized model from Qualcomm AI Hub."
        ) from error
    if quantized_model is None:
        raise CompileError("Quantization job did not produce a target model.")

    # Save and log the quantized ONNX model as an intermediate artifact.
    from embedl_hub._internal.core.utils.qai_hub_utils import (
        save_qai_hub_model,
    )

    quantized_path = save_qai_hub_model(
        quantized_model,
        ctx.artifact_dir / f"quantized_{onnx_path.name}",
    )
    ctx.client.log_artifact(
        quantized_path,
        name="quantized",
        file_name=f"quantized_{onnx_path.name}",
    )
    return quantized_path
def _build_name_mappings(
    source_model_path: Path,
    compiled_model_path: Path,
) -> tuple[dict[str, str], dict[str, str]]:
    """Build input/output name mappings between an ONNX and a TFLite model.

    Pairs the input and output tensor names of the source ONNX model
    with those of the compiled TFLite model **by position**; only names
    that actually changed appear in the returned dictionaries.

    :param source_model_path: Path to the pre-compile ONNX model.
    :param compiled_model_path: Path to the compiled ``.tflite`` model.
    :return: A ``(input_mapping, output_mapping)`` tuple where each dict
        maps original names to compiled names. Empty dicts if no names
        changed.
    :raises CompileError: If the number of inputs or outputs differs
        between the two models.
    """
    from embedl_hub._internal.core.utils.tflite_utils import (
        get_tflite_model_input_names,
        get_tflite_model_output_names,
    )

    onnx_model = load_onnx_model(source_model_path)
    tflite_path = str(compiled_model_path)

    def _positional_diff(
        src_names: list[str], tgt_names: list[str], kind: str
    ) -> dict[str, str]:
        # Pair names by position and keep only the renamed ones.
        if len(src_names) != len(tgt_names):
            raise CompileError(
                f"{kind} count mismatch: source model has "
                f"{len(src_names)} {kind.lower()}(s) but compiled model has "
                f"{len(tgt_names)}. Cannot build name mapping."
            )
        return {
            src: tgt
            for src, tgt in zip(src_names, tgt_names)
            if src != tgt
        }

    input_mapping = _positional_diff(
        [inp.name for inp in onnx_model.graph.input],
        get_tflite_model_input_names(tflite_path),
        "Input",
    )
    output_mapping = _positional_diff(
        [out.name for out in onnx_model.graph.output],
        get_tflite_model_output_names(tflite_path),
        "Output",
    )
    return input_mapping, output_mapping
def _compile_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
    *,
    input_shape: tuple[int, ...] | None,
    quantize_io: bool,
    output_name: str = "compiled_model.tflite",
) -> Path:
    """Submit a compile job to Qualcomm AI Hub targeting TFLite.

    :param ctx: The execution context.
    :param model_path: Path to the (quantized) ONNX model.
    :param hub_device: Target QAI Hub device.
    :param input_shape: Optional input shape override.
    :param quantize_io: Whether to quantize I/O tensors.
    :param output_name: Filename for the compiled model artifact.
    :return: Local path to the compiled TFLite model.
    :raises CompileError: If the compile job fails.
    """
    # Assemble compile options targeting the TFLite runtime.
    option_parts = ["--target_runtime tflite"]
    if quantize_io:
        option_parts.append("--quantize_io")
    options = " ".join(option_parts)

    # NOTE(review): assumes a single model input named "image" when an
    # input shape override is given — confirm against typical callers.
    specs = {"image": input_shape} if input_shape else None

    # Submit the compile job.
    try:
        compile_job: CompileJob = hub.submit_compile_job(
            model=model_path.as_posix(),
            device=hub_device,
            options=options,
            input_specs=specs,
        )
    except Exception as error:
        raise CompileError("Failed to submit compile job.") from error

    # The "$"-prefixed key is logged alongside the plain one —
    # presumably a tracking-system convention; keep both as-is.
    job_id = compile_job.job_id
    ctx.client.log_param("$qai_hub_job_id", job_id)
    ctx.client.log_param("qai_hub_compile_job_id", job_id)

    # Download the compiled model.
    try:
        compiled_model = compile_job.get_target_model()
    except Exception as error:
        raise CompileError(
            "Failed to download compiled model from Qualcomm AI Hub."
        ) from error
    if compiled_model is None:
        raise CompileError("Compile job did not produce a target model.")

    compiled_model_path = save_qai_hub_tflite_model(
        compiled_model, ctx.artifact_dir / output_name
    )

    # Parse and log runtime information for the finished job.
    log_runtime_info(ctx.client, job_id, error_class=CompileError)
    return compiled_model_path