# Copyright (C) 2026 Embedl AB
"""Device manager for creating and managing device instances."""
from __future__ import annotations
from typing import TYPE_CHECKING
from embedl_hub._internal.core.component.provider_type import ProviderType
from embedl_hub._internal.core.device.abc import Device
from embedl_hub._internal.core.device.config.embedl_onnxruntime import (
EmbedlONNXRuntimeConfig,
)
from embedl_hub._internal.core.device.config.trtexec import TrtexecConfig
from embedl_hub._internal.core.device.spec import DeviceSpec
from embedl_hub._internal.core.device.ssh import SSHCommandRunner, SSHConfig
if TYPE_CHECKING:
from embedl_hub._internal.core.component.abc import Component
[docs]
class DeviceManager:
    """Factory for :class:`Device` instances.

    Every method is a classmethod that assembles a ``Device`` from a
    name, an optional command runner, a ``DeviceSpec`` and a provider
    type.  The class itself holds no state.
    """

    @classmethod
    def get_ssh_device(
        cls,
        config: SSHConfig,
        *,
        name: str = "main",
        provider_type: str = ProviderType.EMBEDL_ONNXRUNTIME,
    ) -> Device:
        """Create a device that is driven through an SSH command runner.

        :param config: SSH connection configuration, passed to
            ``SSHCommandRunner``.
        :param name: Human-readable label for this device.
        :param provider_type: The provider type identifier
            (e.g. ``ProviderType.EMBEDL_ONNXRUNTIME``,
            ``ProviderType.TRTEXEC``).
        :return: A Device instance with an SSH runner and a default
            ``DeviceSpec``.
        """
        return Device(
            name=name,
            runner=SSHCommandRunner(config),
            spec=DeviceSpec(),
            provider_type=provider_type,
        )

    @classmethod
    def get_qai_hub_device(
        cls,
        device_name: str,
        *,
        name: str = "main",
    ) -> Device:
        """Create a device for Qualcomm AI Hub.

        :param device_name: The QAI Hub device name
            (e.g., 'Samsung Galaxy S24').
        :param name: Human-readable label for this device.
        :return: A Device instance with provider type
            ``ProviderType.QAI_HUB`` and no command runner.
        """
        return Device(
            name=name,
            runner=None,
            spec=DeviceSpec(device_name=device_name),
            provider_type=ProviderType.QAI_HUB,
        )

    @classmethod
    def get_embedl_onnxruntime_device(
        cls,
        config: SSHConfig,
        *,
        name: str = "main",
        provider_config: EmbedlONNXRuntimeConfig | None = None,
        overrides: dict[type[Component], EmbedlONNXRuntimeConfig]
        | None = None,
    ) -> Device:
        """Create an SSH device for ONNX Runtime compilation.

        :param config: SSH connection configuration.
        :param name: Human-readable label for this device.
        :param provider_config: Typed configuration for the
            ``embedl-onnxruntime`` CLI provider.  When omitted (or
            ``None``), a new default ``EmbedlONNXRuntimeConfig`` is
            created for this device.
        :param overrides: Optional per-component config overrides
            keyed by concrete component class.  The mapping is
            shallow-copied before being handed to the device.
        :return: A Device instance configured for remote ORT compilation.
        """
        # Build the default per call: a call expression in the signature
        # is evaluated once at definition time, which would make every
        # device created without an explicit config share one object.
        if provider_config is None:
            provider_config = EmbedlONNXRuntimeConfig()
        return Device(
            name=name,
            runner=SSHCommandRunner(config),
            spec=DeviceSpec(),
            provider_type=ProviderType.EMBEDL_ONNXRUNTIME,
            provider_config=provider_config,
            provider_config_overrides=dict(overrides) if overrides else {},
        )

    @classmethod
    def get_tensorrt_device(
        cls,
        config: SSHConfig,
        *,
        name: str = "main",
        provider_config: TrtexecConfig | None = None,
        overrides: dict[type[Component], TrtexecConfig] | None = None,
    ) -> Device:
        """Create an SSH device for TensorRT compilation.

        :param config: SSH connection configuration.
        :param name: Human-readable label for this device.
        :param provider_config: Typed configuration for the ``trtexec``
            CLI provider.  When omitted (or ``None``), a new default
            ``TrtexecConfig`` is created for this device.
        :param overrides: Optional per-component config overrides
            keyed by concrete component class.  The mapping is
            shallow-copied before being handed to the device.
        :return: A Device instance configured for remote TensorRT
            compilation.
        """
        # Build the default per call: a call expression in the signature
        # is evaluated once at definition time, which would make every
        # device created without an explicit config share one object.
        if provider_config is None:
            provider_config = TrtexecConfig()
        return Device(
            name=name,
            runner=SSHCommandRunner(config),
            spec=DeviceSpec(),
            provider_type=ProviderType.TRTEXEC,
            provider_config=provider_config,
            provider_config_overrides=dict(overrides) if overrides else {},
        )

    @classmethod
    def get_aws_device(
        cls,
        device_name: str,
        *,
        name: str = "main",
    ) -> Device:
        """Create a device for the Embedl device cloud (AWS Device Farm).

        :param device_name: The Embedl device cloud device name.
        :param name: Human-readable label for this device.
        :return: A Device instance with provider type
            ``ProviderType.AWS`` and no command runner.
        """
        return Device(
            name=name,
            runner=None,
            spec=DeviceSpec(device_name=device_name),
            provider_type=ProviderType.AWS,
        )