# Module: embedl_hub._internal.core.invoke.tflite

# Copyright (C) 2026 Embedl AB

"""TFLite invoker component.

Supports inference via:
- **``qai_hub``** devices: Qualcomm AI Hub cloud inference using the
  TensorFlow Lite runtime.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import qai_hub as hub

from embedl_hub._internal.core.compile.tflite import TFLiteCompiledModel
from embedl_hub._internal.core.component.abc import Component, NoProviderError
from embedl_hub._internal.core.component.output import ComponentOutput
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.core.invoke import InvokeError
from embedl_hub._internal.core.utils.qai_hub_utils import (
    log_runtime_info,
    resolve_qai_hub_device,
)
from embedl_hub._internal.tracking.rest_api import RunType
from embedl_hub._internal.tracking.run_log import LoggedArtifact

# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class TFLiteInvocationResult(ComponentOutput):
    """Result of a TFLite inference step.

    :param output: Mapping from output tensor name to its numpy array.
    :param output_file: The logged ``.npz`` artifact containing the outputs.
    """

    output: dict[str, np.ndarray]
    output_file: LoggedArtifact
# ---------------------------------------------------------------------------
# Component class
# ---------------------------------------------------------------------------


class TFLiteInvoker(Component):
    """Run inference on a compiled TFLite model.

    The actual work is performed by a device-specific provider selected by
    the component system according to the configured device type.
    """

    run_type = RunType.INFERENCE

    def __init__(
        self,
        *,
        name: str | None = None,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> None:
        """Create a new TFLite invoker.

        Values given here act as per-instance defaults; each can be
        overridden for a single invocation by passing the same keyword to
        :meth:`run`.

        :param name: Display name for this component instance.
        :param device: Name of the target device.
        :param use_compiled_names: If ``True``, assume input names already
            match the compiled model's tensor names and skip automatic
            remapping.
        """
        super().__init__(
            name=name,
            device=device,
            use_compiled_names=use_compiled_names,
        )

    def run(
        self,
        ctx: HubContext,
        model: TFLiteCompiledModel,
        input_data: dict[str, np.ndarray],
        *,
        device: str | None = None,
        use_compiled_names: bool = False,
    ) -> TFLiteInvocationResult:
        """Run inference on a compiled TFLite model.

        Keyword arguments given here take precedence over the defaults set
        in the constructor; omitted ones fall back to the constructor
        values.

        :param ctx: The execution context with device configuration.
        :param model: The compiled TFLite model (from
            :class:`TFLiteCompiler`).
        :param input_data: Dictionary mapping input tensor names to numpy
            arrays.
        :param device: Name of the target device.
        :param use_compiled_names: If ``False`` (default), input names are
            remapped to the compiled model's names before submission and
            output names are mapped back to the originals.
        :return: A :class:`TFLiteInvocationResult` with inference output.
        """
        # NOTE(review): with a plain ``False`` default there is no way to
        # distinguish "not passed" from an explicit ``False`` here, unlike
        # ``device`` which uses ``None`` as its sentinel — confirm the
        # component system still honours ``use_compiled_names=True`` set in
        # the constructor when this method is called without the keyword.
        raise NoProviderError  # Replaced by the component system.
# --------------------------------------------------------------------- #
# QAI Hub provider
# --------------------------------------------------------------------- #


@TFLiteInvoker.provider(ProviderType.QAI_HUB)
def _invoke_tflite_qai_hub(
    ctx: HubContext,
    model: TFLiteCompiledModel,
    input_data: dict[str, np.ndarray],
    *,
    device: str | None = None,
    use_compiled_names: bool = False,
) -> TFLiteInvocationResult:
    """Run inference on a TFLite model using Qualcomm AI Hub.

    :param ctx: The execution context with device configuration.
    :param model: The compiled TFLite model.
    :param input_data: Dictionary mapping input tensor names to numpy
        arrays.
    :param device: Name of the target device.
    :param use_compiled_names: If ``False`` (default), input names are
        remapped to the compiled model's names before submission and
        output names are mapped back to the originals.
    :return: A :class:`TFLiteInvocationResult` with inference output.
    """
    # -- resolve model path (auto-transfer to local) -------------------
    model_local_path = model.path.to_local(ctx).local()
    input_name_mapping = model.input_name_mapping
    output_name_mapping = model.output_name_mapping

    # -- validate device ------------------------------------------------
    # The resolved display name is unused here; resolution is kept for
    # its validation side effect.
    hub_device, _ = resolve_qai_hub_device(ctx, device)

    # -- remap input names if needed ------------------------------------
    if not use_compiled_names and input_name_mapping:
        input_data = _remap_input_names(input_data, input_name_mapping)

    # -- log input artifact ---------------------------------------------
    # Bug fix: this step previously logged the compiled *model* file
    # (``model.path``) under the name "input". Log the actual input
    # tensors instead, mirroring the output-artifact handling below.
    input_npz_path = ctx.artifact_dir / "input.npz"
    np.savez_compressed(input_npz_path, **input_data)
    ctx.client.log_artifact(input_npz_path, name="input")

    # -- invoke via QAI Hub --------------------------------------------
    output_data = _invoke_via_qai_hub(
        ctx, model_local_path, hub_device, input_data
    )

    # -- remap output names back to originals ---------------------------
    if not use_compiled_names and output_name_mapping:
        reverse_output = {v: k for k, v in output_name_mapping.items()}
        output_data = {
            reverse_output.get(name, name): array
            for name, array in output_data.items()
        }

    # -- save and log output artifact -----------------------------------
    output_npz_path = ctx.artifact_dir / "output.npz"
    np.savez_compressed(output_npz_path, **output_data)
    ctx.client.log_artifact(output_npz_path, name="output_file")

    return TFLiteInvocationResult.from_current_run(ctx, output=output_data)


# -- QAI Hub helpers -------------------------------------------------------


def _remap_input_names(
    input_data: dict[str, np.ndarray],
    name_mapping: dict[str, str],
) -> dict[str, np.ndarray]:
    """Remap input tensor names from original to compiled names.

    :param input_data: Original input data keyed by pre-compile names.
    :param name_mapping: Mapping from original to compiled names.
    :return: New dict with keys replaced according to `name_mapping`.
    :raises InvokeError: If an input name is not found in the mapping and
        does not match any compiled name either.
    """
    compiled_names = set(name_mapping.values())
    remapped: dict[str, np.ndarray] = {}
    for name, array in input_data.items():
        if name in name_mapping:
            remapped[name_mapping[name]] = array
        elif name in compiled_names:
            # Already using compiled names — pass through.
            remapped[name] = array
        else:
            raise InvokeError(
                f"Input name '{name}' not found in the name mapping "
                f"and is not a compiled model input name. Known "
                f"original names: {sorted(name_mapping.keys())}. "
                f"Compiled names: {sorted(compiled_names)}."
            )
    return remapped


def _invoke_via_qai_hub(
    ctx: HubContext,
    model_path: Path,
    hub_device: hub.Device,
    input_data: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Submit an inference job to Qualcomm AI Hub and download results.

    :param ctx: The execution context.
    :param model_path: Local path to the ``.tflite`` model.
    :param hub_device: Target QAI Hub device.
    :param input_data: Dictionary mapping input tensor names to numpy
        arrays.
    :return: Dictionary mapping output tensor names to numpy arrays.
    :raises InvokeError: If the inference job fails.
    """
    # QAI Hub expects each entry as a batch list, so wrap every array.
    hub_inputs: dict[str, list[np.ndarray]] = {
        name: [array] for name, array in input_data.items()
    }
    try:
        job: hub.InferenceJob = hub.submit_inference_job(
            model=model_path.as_posix(),
            device=hub_device,
            inputs=hub_inputs,
        )
    except Exception as error:
        raise InvokeError(
            f"Failed to submit inference job to Qualcomm AI Hub: {error}"
        ) from error

    # Logged twice on purpose, presumably: "$qai_hub_job_id" looks like a
    # reserved/system param, the other a step-specific one — TODO confirm.
    ctx.client.log_param("$qai_hub_job_id", job.job_id)
    ctx.client.log_param("qai_hub_inference_job_id", job.job_id)

    try:
        output_dataset = job.download_output_data()
    except Exception as error:
        raise InvokeError(
            f"Failed to download inference output from Qualcomm AI Hub: {error}"
        ) from error

    if output_dataset is None:
        raise InvokeError(
            f"Inference job returned no output data. Job ID: {job.job_id}"
        )

    log_runtime_info(ctx.client, job.job_id, error_class=InvokeError)

    # Unwrap the single-batch lists returned by QAI Hub; coerce anything
    # else to an ndarray for a uniform return type.
    output_data: dict[str, np.ndarray] = {}
    for name, arrays in output_dataset.items():
        if isinstance(arrays, list) and arrays:
            output_data[name] = arrays[0]
        else:
            output_data[name] = np.asarray(arrays)
    return output_data