# Copyright (C) 2026 Embedl AB
"""TFLite invoker component.
Supports inference via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud inference using the
TensorFlow Lite runtime.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import qai_hub as hub
from embedl_hub._internal.core.compile.tflite import TFLiteCompiledModel
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.invoke import InvokeError
from embedl_hub._internal.core.utils.qai_hub_utils import (
log_runtime_info,
resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact
# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class TFLiteInvocationResult(ComponentOutput):
    """Output of a TFLite inference step.

    :param output: Dictionary mapping output tensor names to numpy arrays.
    :param output_file: The logged output ``.npz`` artifact.
    """

    # Inference results keyed by output tensor name (original names unless
    # the caller opted into compiled names).
    output: dict[str, np.ndarray]
    # Handle to the ``output.npz`` artifact logged for this run.
    output_file: LoggedArtifact
# ---------------------------------------------------------------------------
# Component class
# ---------------------------------------------------------------------------
class TFLiteInvoker(Component):
    """Run inference on a compiled TFLite model.

    Dispatches to a device-specific implementation based on the
    configured device type.
    """

    # Runs created by this component are tracked as inference runs.
    run_type = RunType.INFERENCE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> None:
        """Create a new TFLite invoker.

        Parameters passed here become the defaults for every call.
        They can be overridden per call by passing them to
        :meth:`run` directly.

        :param name: Display name for this component instance.
        :param device: Name of the target device.
        :param use_compiled_names: If ``True``, assume input names
            already match the compiled model's tensor names and skip
            automatic remapping.
        """
        # Defaults are stored by the Component base; run() keyword
        # arguments override them at call time.
        super().__init__(
            name=name,
            device=device,
            use_compiled_names=use_compiled_names,
        )

    def run(
        self,
        ctx: HubContext,
        model: TFLiteCompiledModel,
        input_data: dict[str, np.ndarray],
        *,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> TFLiteInvocationResult:
        """Run inference on a compiled TFLite model.

        Keyword arguments override the defaults set in the constructor.
        If a keyword argument is not provided here, the value from the
        constructor is used.

        :param ctx: The execution context with device configuration.
        :param model: The compiled TFLite model (from :class:`TFLiteCompiler`).
        :param input_data: Dictionary mapping input tensor names to numpy
            arrays.
        :param device: Name of the target device.
        :param use_compiled_names: If ``False`` (default), input names are
            remapped to the compiled model's names before submission and
            output names are mapped back to the originals.
        :return: A :class:`TFLiteInvocationResult` with inference output.
        """
        # This stub is never executed: the component system swaps in the
        # provider implementation registered for the active device type.
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #
@TFLiteInvoker.provider(ProviderType.QAI_HUB)
def _invoke_tflite_qai_hub(
    ctx: HubContext,
    model: TFLiteCompiledModel,
    input_data: dict[str, np.ndarray],
    *,
    device: str | None = None,
    use_compiled_names: bool = False,
) -> TFLiteInvocationResult:
    """Run inference on a TFLite model using Qualcomm AI Hub.

    :param ctx: The execution context with device configuration.
    :param model: The compiled TFLite model.
    :param input_data: Dictionary mapping input tensor names to numpy arrays.
    :param device: Name of the target device.
    :param use_compiled_names: If ``False`` (default), input names are
        remapped to the compiled model's names before submission and
        output names are mapped back to the originals.
    :return: A :class:`TFLiteInvocationResult` with inference output.
    """
    # -- resolve model path (auto-transfer to local) -------------------
    # Ensures the .tflite file is available on the local filesystem.
    model_local_path = model.path.to_local(ctx).local()
    input_name_mapping = model.input_name_mapping
    output_name_mapping = model.output_name_mapping
    # -- validate device ------------------------------------------------
    # NOTE(review): device_name is unused after resolution — resolution is
    # kept for its validation side effect; confirm intent.
    hub_device, device_name = resolve_qai_hub_device(ctx, device)
    # -- remap input names if needed ------------------------------------
    # Only remap when a mapping exists; an empty mapping means the
    # compiled model kept the original names.
    if not use_compiled_names and input_name_mapping:
        input_data = _remap_input_names(input_data, input_name_mapping)
    # -- log input artifact ---------------------------------------------
    # NOTE(review): this logs the compiled model file (model.path) under
    # the "input" artifact name, not the input tensors — confirm intended.
    ctx.client.log_artifact(
        model.path,
        name="input",
        file_name=f"input_{model.path.file_name}",
        ctx=ctx,
    )
    # -- invoke via QAI Hub --------------------------------------------
    output_data = _invoke_via_qai_hub(
        ctx, model_local_path, hub_device, input_data
    )
    # -- remap output names back to originals ---------------------------
    if not use_compiled_names and output_name_mapping:
        # Invert original->compiled so compiled names map back to the
        # originals; names absent from the mapping pass through unchanged.
        reverse_output = {v: k for k, v in output_name_mapping.items()}
        output_data = {
            reverse_output.get(name, name): array
            for name, array in output_data.items()
        }
    # -- save and log output artifact -----------------------------------
    output_npz_path = ctx.artifact_dir / "output.npz"
    np.savez_compressed(output_npz_path, **output_data)
    ctx.client.log_artifact(output_npz_path, name="output_file")
    # output_file on the result is populated from the run's logged
    # artifacts by from_current_run (matching the "output_file" name above).
    return TFLiteInvocationResult.from_current_run(ctx, output=output_data)
# -- QAI Hub helpers -------------------------------------------------------
def _remap_input_names(
input_data: dict[str, np.ndarray],
name_mapping: dict[str, str],
) -> dict[str, np.ndarray]:
"""Remap input tensor names from original to compiled names.
:param input_data: Original input data keyed by pre-compile names.
:param name_mapping: Mapping from original to compiled names.
:return: New dict with keys replaced according to `name_mapping`.
:raises InvokeError: If an input name is not found in the mapping
and does not match any compiled name either.
"""
compiled_names = set(name_mapping.values())
remapped: dict[str, np.ndarray] = {}
for name, array in input_data.items():
if name in name_mapping:
remapped[name_mapping[name]] = array
elif name in compiled_names:
# Already using compiled names — pass through.
remapped[name] = array
else:
raise InvokeError(
f"Input name '{name}' not found in the name mapping "
f"and is not a compiled model input name. Known "
f"original names: {sorted(name_mapping.keys())}. "
f"Compiled names: {sorted(compiled_names)}."
)
return remapped
def _invoke_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
    input_data: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Submit an inference job to Qualcomm AI Hub and download results.

    :param ctx: The execution context.
    :param model_path: Local path to the ``.tflite`` model.
    :param hub_device: Target QAI Hub device.
    :param input_data: Dictionary mapping input tensor names to numpy arrays.
    :return: Dictionary mapping output tensor names to numpy arrays.
    :raises InvokeError: If the inference job fails.
    """
    # QAI Hub expects each input tensor wrapped as a single-element batch.
    batched_inputs: dict[str, list[np.ndarray]] = {
        key: [value] for key, value in input_data.items()
    }
    try:
        job: hub.InferenceJob = hub.submit_inference_job(
            model=model_path.as_posix(),
            device=hub_device,
            inputs=batched_inputs,
        )
    except Exception as error:
        raise InvokeError(
            f"Failed to submit inference job to Qualcomm AI Hub: {error}"
        ) from error
    # Record the job id under both the generic and the job-specific keys.
    ctx.client.log_param("$qai_hub_job_id", job.job_id)
    ctx.client.log_param("qai_hub_inference_job_id", job.job_id)
    try:
        output_dataset = job.download_output_data()
    except Exception as error:
        raise InvokeError(
            f"Failed to download inference output from Qualcomm AI Hub: {error}"
        ) from error
    if output_dataset is None:
        raise InvokeError(
            f"Inference job returned no output data. Job ID: {job.job_id}"
        )
    log_runtime_info(ctx.client, job.job_id, error_class=InvokeError)
    # Unwrap the single-element batches; coerce any non-list payloads.
    results: dict[str, np.ndarray] = {}
    for tensor_name, payload in output_dataset.items():
        if isinstance(payload, list) and payload:
            results[tensor_name] = payload[0]
        else:
            results[tensor_name] = np.asarray(payload)
    return results