# Copyright (C) 2026 Embedl AB
"""Core device abstractions.
This module defines the :class:`Device` dataclass and the protocols that
runners must implement (:class:`CommandRunner`, :class:`Connectable`).
"""
from __future__ import annotations
from collections.abc import Generator, Sequence
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, Self, TypeVar, runtime_checkable
from embedl_hub._internal.core.device.artifact_manager import (
RemoteArtifactManager,
)
from embedl_hub._internal.core.device.config.abc import ProviderConfig
from embedl_hub._internal.core.device.spec import DeviceSpec
from embedl_hub._internal.core.types import LocalPath, RemotePath
if TYPE_CHECKING:
from embedl_hub._internal.core.component.abc import Component
from embedl_hub._internal.tracking.run_log import DeviceLog
# Type variable used by ``Device.get_provider_config``: binding it to
# ``ProviderConfig`` lets the method's return type narrow to whichever
# concrete config class the caller passes as ``config_type``.
_T = TypeVar("_T", bound=ProviderConfig)
class CommandResult(Protocol):
    """Protocol for the result of executing a command on a device.

    Any object exposing the two attributes below satisfies this protocol.
    """

    # Captured standard output of the command, or ``None``.
    stdout: str | None
    # Captured standard error of the command, or ``None``.
    stderr: str | None
class CommandRunner(Protocol):
    """Protocol for executing commands and transferring files on a device."""

    # NOTE(review): the ``def`` line of this method was missing from the
    # source (the parameter list started directly after the class
    # docstring). The name ``run`` is reconstructed from the signature
    # and the sibling ``get``/``put`` methods -- confirm against the
    # concrete runner implementations.
    def run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Execute *command* on the device.

        :param command: The command and its arguments, one item per token.
        :param cwd: Optional remote working directory for the command.
        :param hide: Keyword-only flag; presumably suppresses echoing of
            the command output -- confirm with the implementations.
        :return: An object exposing the command's ``stdout`` and
            ``stderr``.
        """
        ...

    def get(self, source: RemotePath, destination: LocalPath) -> None:
        """Transfer the remote *source* path to the local *destination*."""
        ...

    def put(self, source: LocalPath, destination: RemotePath) -> None:
        """Transfer the local *source* path to the remote *destination*."""
        ...
@runtime_checkable
class Connectable(Protocol):
"""Protocol for runners that require an explicit connection step.
Runners implementing this protocol provide a :meth:`connect` context
manager that establishes and tears down the underlying connection.
The :attr:`is_active` property indicates whether a connection is
currently open.
"""
@property
def is_active(self) -> bool:
"""Whether a connection is currently established."""
...
def connect(self) -> AbstractContextManager[Self]: ...
[docs]
@dataclass
class Device:
"""Represents a target device for component execution.
Directory and artifact management is delegated to an internal
:class:`RemoteArtifactManager`. The most commonly used properties
(``artifact_base_dir``, ``artifact_dir``, ``last_download_dir``) and
the ``download_artifact_dir`` method are exposed as public delegates
for convenience.
:param name: A human-readable label for this device (e.g.
``'main'``, ``'jetson-orin'``). Used as the dictionary key in
:attr:`HubContext.devices` and shown in console output.
:param runner: Optional command runner for executing commands on the device.
:param spec: The device specification (platform and environment).
:param provider_type: The provider type identifier for this device.
:param provider_config: Default :class:`ProviderConfig` used by all
components unless overridden per-component.
:param provider_config_overrides: Optional per-component overrides
keyed by the concrete :class:`Component` subclass.
"""
name: str
runner: CommandRunner | None
spec: DeviceSpec
provider_type: str
provider_config: ProviderConfig | None = None
provider_config_overrides: dict[type[Component], ProviderConfig] = field(
default_factory=dict
)
_artifacts: RemoteArtifactManager = field(init=False, repr=False)
_in_context: bool = field(default=False, init=False, repr=False)
def __post_init__(self) -> None:
self._artifacts = RemoteArtifactManager(runner=self.runner)
# Propagate the device name to the runner for display purposes.
if self.runner is not None and hasattr(self.runner, "device_name"):
self.runner.device_name = self.name
@property
def is_active(self) -> bool:
"""Whether this device is inside a :meth:`connect` context."""
return self._in_context
# -- delegated properties ------------------------------------------------
@property
def artifact_base_dir(self) -> RemotePath | None:
"""Get the remote base directory for artifacts."""
return self._artifacts.base_dir
@artifact_base_dir.setter
def artifact_base_dir(self, value: RemotePath | None) -> None:
"""Set the remote base directory for artifacts.
:raises RuntimeError: If called while inside a device context.
"""
self._artifacts.base_dir = value
@property
def artifact_dir(self) -> RemotePath | None:
"""Get the artifact directory for the current run."""
return self._artifacts.artifact_dir
@property
def last_download_dir(self) -> LocalPath | None:
"""Get the local path of the most recent artifact download."""
return self._artifacts.last_download_dir
# -- delegated methods ---------------------------------------------------
[docs]
def download_artifact_dir(self, destination: LocalPath) -> LocalPath:
"""Download the current remote artifact directory to a local path.
:param destination: The local directory to download into.
:return: The local path where the artifacts were downloaded.
:raises RuntimeError: If not inside a device context.
:raises RuntimeError: If there is no command runner.
:raises RuntimeError: If there is no remote artifact directory.
"""
return self._artifacts.download_artifact_dir(destination)
def _set_artifact_subdir_name(self, subdir_name: str) -> None:
"""Set the artifact sub-directory name for this device.
**Protected** — called exclusively by the component system.
End users should not call this method directly.
Combined with :attr:`artifact_base_dir`, this determines the
:attr:`artifact_dir` (``<base_dir>/<subdir_name>``).
"""
self._artifacts.set_subdir_name(subdir_name)
def _clear_artifact_subdir_name(self) -> None:
"""Clear the artifact sub-directory name, resetting :attr:`artifact_dir`.
**Protected** — called exclusively by the component system.
End users should not call this method directly.
"""
self._artifacts.clear_subdir_name()
# -- provider config ------------------------------------------------------
[docs]
def get_provider_config(
self,
config_type: type[_T],
component_cls: type[Component],
) -> _T | None:
"""Resolve the provider config for a specific component class.
Walks the *component_cls* MRO looking for a matching override
in :attr:`provider_config_overrides`. If none is found, falls
back to :attr:`provider_config`. In both cases the config must
be an instance of *config_type*; if it is not, ``None`` is
returned.
This method is generic on *config_type*: the return type is
``T | None`` where ``T`` is the concrete config class passed in,
giving both runtime validation and static type narrowing::
cfg = dev.get_provider_config(TrtexecConfig, TensorRTCompiler)
# type: TrtexecConfig | None
:param config_type: The expected concrete
:class:`ProviderConfig` subclass.
:param component_cls: The component class requesting the config.
:return: The resolved config, or ``None`` if no matching config
is found.
"""
for cls in component_cls.__mro__:
if cls in self.provider_config_overrides:
cfg = self.provider_config_overrides[cls]
if isinstance(cfg, config_type):
return cfg
return None
if isinstance(self.provider_config, config_type):
return self.provider_config
return None
# -- device-specific logic -----------------------------------------------
[docs]
def create_device_log(self, name: str) -> DeviceLog:
"""Create a :class:`~embedl_hub._internal.tracking.run_log.DeviceLog` snapshot of
this device's current state.
:param name: Logical name for this device in the run record.
:return: A frozen :class:`~embedl_hub._internal.tracking.run_log.DeviceLog`
capturing the current ``artifact_dir``, ``device_spec``,
``provider_type``, ``provider_config``, and
``downloaded_artifact_dir``.
"""
from embedl_hub._internal.tracking.run_log import DeviceLog
return DeviceLog(
name=name,
artifact_dir=self.artifact_dir if self.is_active else None,
spec=self.spec,
provider_type=self.provider_type,
provider_metadata={},
downloaded_artifact_dir=self._artifacts.last_download_dir,
)
[docs]
@contextmanager
def connect(self) -> Generator[Self, None, None]:
"""Connect the device, establishing a runner connection if supported.
:yields: This device instance with an active connection.
"""
if self.runner is not None and isinstance(self.runner, Connectable):
runner_cm = self.runner.connect()
runner_cm.__enter__()
else:
runner_cm = None
try:
self._in_context = True
with self._artifacts.activate():
yield self
finally:
self._in_context = False
if runner_cm is not None:
runner_cm.__exit__(None, None, None)