# Source code for embedl_hub._internal.core.component.output

# Copyright (C) 2026 Embedl AB

"""Component output base class for typed access to run log data."""

import types
from dataclasses import dataclass, replace
from datetime import UTC, datetime
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    ClassVar,
    Self,
    TypeVar,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

from embedl_hub._internal.core.types import LocalPath, RemotePath
from embedl_hub._internal.tracking.run_log import (
    DeviceLog,
    LoggedArtifact,
    LoggedMetric,
    LoggedParam,
    RunLog,
)

if TYPE_CHECKING:
    from embedl_hub._internal.core.context import HubContext

# The logged-entity classes a ComponentOutput subclass may use as a field
# annotation, either directly (required) or as ``LoggedX | None`` (optional).
# Fields annotated with any other type are not populated from a run log.
_VALID_FIELD_TYPES = (LoggedParam, LoggedMetric, LoggedArtifact)


def _unwrap_optional_logged_type(hint: object) -> type | None:
    """Extract ``LoggedX`` from an optional annotation such as ``LoggedX | None``.

    Every spelling of an optional logged type is recognised:

    * ``LoggedParam | None``  /  ``None | LoggedParam``
    * ``Union[LoggedMetric, None]``
    * ``Optional[LoggedArtifact]``

    :param hint: A resolved type annotation.
    :return: The single non-``None`` member of the union when it is one of the
        valid logged types, otherwise ``None`` (including for non-union hints
        and unions with more than one non-``None`` member).
    """
    # ``typing.Union[...]`` reports ``Union`` as its origin, while the PEP 604
    # ``X | Y`` form produces a ``types.UnionType`` instance.
    is_union = get_origin(hint) is Union or isinstance(hint, types.UnionType)
    if not is_union:
        return None
    candidates = [member for member in get_args(hint) if member is not type(None)]
    if len(candidates) == 1 and candidates[0] in _VALID_FIELD_TYPES:
        return candidates[0]
    return None


class MissingOutputError(Exception):
    """Signal that a required output is absent from a run log.

    :param field_name: Name of the output field that could not be resolved.
    :param field_type: The class the field was declared with; its
        ``__name__`` is embedded in the error message.
    """

    def __init__(self, field_name: str, field_type: type) -> None:
        message = (
            f"Required output '{field_name}' of type "
            f"{field_type.__name__} not found in run log."
        )
        super().__init__(message)
        # Expose the raw inputs so callers can react programmatically.
        self.field_name = field_name
        self.field_type = field_type
@dataclass(frozen=True)
class ComponentOutput:
    """Base class for component outputs with typed access to run data.

    Subclasses are frozen dataclasses whose fields are annotated as
    :class:`LoggedParam`, :class:`LoggedMetric`, or :class:`LoggedArtifact`
    (or ``LoggedX | None`` for optional outputs). They can be constructed
    manually or from a :class:`RunLog` via :meth:`from_run_log`.

    Example::

        @dataclass(frozen=True)
        class MyOutput(ComponentOutput):
            accuracy: LoggedMetric
            learning_rate: LoggedParam
            model: LoggedArtifact

        # Construct manually (the base fields are required too):
        output = MyOutput(
            artifact_dir=None,
            devices={},
            run_log=None,
            accuracy=LoggedMetric(...),
            learning_rate=LoggedParam(...),
            model=LoggedArtifact(...),
        )

        # Or construct from a run log:
        output = MyOutput.from_run_log(run_log)
        print(output.accuracy.value)

    If a required field is defined in the class but no matching logged
    entity with that name exists in the :class:`RunLog`, a
    :class:`MissingOutputError` is raised.

    The :attr:`run_log` field holds a reference to the :class:`RunLog` the
    output was created from, or ``None`` if the output was constructed
    manually.
    """

    artifact_dir: LocalPath | None
    devices: dict[str, DeviceLog]
    run_log: RunLog | None

    # Field name -> LoggedX type, and the subset of names that are optional.
    # ``__init_subclass__`` assigns fresh objects on every subclass (they are
    # never mutated in place). The empty defaults let ``ComponentOutput``
    # itself be used with ``from_run_log`` and mark it as a field-carrying
    # class for the MRO walk below. ClassVars are not dataclass fields, so
    # mutable defaults are allowed here.
    _output_fields: ClassVar[dict[str, type]] = {}
    _optional_output_fields: ClassVar[set[str]] = set()

    def __init_subclass__(cls, **kwargs) -> None:
        """Collect the logged-output fields declared on a subclass.

        Populates ``_output_fields`` (field name -> ``LoggedX`` type) and
        ``_optional_output_fields`` for *cls*. Annotations are processed from
        the most distant :class:`ComponentOutput` ancestor down to *cls*
        itself, so a redeclaration in a nearer class always wins. A nearer
        declaration can change a field's logged type or make it required or
        optional. It can also stop tracking the field by giving it a
        non-logged type.

        Fields of other types (e.g. raw data dicts) are allowed but are not
        populated by :meth:`from_run_log`.
        """
        super().__init_subclass__(**kwargs)
        base_field_names = ("artifact_dir", "devices", "run_log")
        tracked: dict[str, type] = {}
        optional: set[str] = set()
        for klass in reversed(cls.__mro__):
            # Only ComponentOutput classes contribute; each of them owns
            # ``_output_fields``. ``cls`` does not own it yet at this point.
            if klass is not cls and "_output_fields" not in vars(klass):
                continue
            # On Python 3.10+ ``klass.__annotations__`` holds only the
            # annotations declared directly on ``klass``.
            own_names = [
                name
                for name in getattr(klass, "__annotations__", None) or ()
                if not name.startswith("_") and name not in base_field_names
            ]
            if not own_names:
                continue
            hints = get_type_hints(klass)
            for name in own_names:
                if name not in hints:
                    continue
                hint = hints[name]
                if hint in _VALID_FIELD_TYPES:
                    # Required: overrides any inherited optional declaration.
                    tracked[name] = hint
                    optional.discard(name)
                    continue
                inner = _unwrap_optional_logged_type(hint)
                if inner is not None:
                    tracked[name] = inner
                    optional.add(name)
                else:
                    # Redeclared with a non-logged type: stop tracking it.
                    tracked.pop(name, None)
                    optional.discard(name)
        cls._output_fields = tracked
        cls._optional_output_fields = optional

    @classmethod
    def from_run_log(cls, run_log: RunLog, **extra_kwargs: object) -> Self:
        """Construct an output instance from a :class:`RunLog`.

        Matches fields by name from the :class:`RunLog`. Names logged with
        a ``$`` prefix are matched to fields without the prefix (e.g.
        ``$path`` -> ``path``). Raises :class:`MissingOutputError` if a
        required field is not found. Raises :class:`ValueError` if both
        ``$name`` and ``name`` exist in the run log.

        Fields typed as ``LoggedX | None`` (or equivalently
        ``Optional[LoggedX]``) are *optional*: if the corresponding entry is
        missing from the run log the field is set to ``None`` instead of
        raising.

        Any additional keyword arguments are forwarded to the constructor
        for fields that are not tracked in the run log (e.g. raw data
        dictionaries).

        :param run_log: The :class:`RunLog` to extract data from.
        :param extra_kwargs: Extra fields to pass to the constructor.
        :return: An instance of the output class with fields populated.
        :raises MissingOutputError: If a required field is not in the run
            log.
        :raises ValueError: If both prefixed and unprefixed versions exist.
        """
        # Where each kind of logged entity lives in the run log.
        sources = {
            LoggedParam: run_log.params,
            LoggedMetric: run_log.metrics,
            LoggedArtifact: run_log.named_artifacts,
        }

        T = TypeVar("T")

        def _resolve_name(field_name: str, source: dict[str, T]) -> str | None:
            """Resolve field name, checking for $ prefix conflicts."""
            prefixed = f"${field_name}"
            has_plain = field_name in source
            has_prefixed = prefixed in source
            if has_plain and has_prefixed:
                raise ValueError(
                    f"Ambiguous field '{field_name}': both '{field_name}' and "
                    f"'{prefixed}' exist in the run log."
                )
            if has_prefixed:
                return prefixed
            if has_plain:
                return field_name
            return None

        kwargs: dict[str, object] = {
            "artifact_dir": run_log.artifact_dir,
            "devices": run_log.devices,
            "run_log": run_log,
        }
        for field_name, field_type in cls._output_fields.items():
            source = sources.get(field_type)
            if source is None:
                continue
            resolved = _resolve_name(field_name, source)
            if resolved is not None:
                kwargs[field_name] = source[resolved]
            elif field_name in cls._optional_output_fields:
                kwargs[field_name] = None
            else:
                raise MissingOutputError(field_name, field_type)

        kwargs.update(extra_kwargs)
        return cls(**kwargs)  # type: ignore[arg-type]

    @classmethod
    def from_yaml(cls, yaml_path: LocalPath, **extra_kwargs: object) -> Self:
        """Construct an output instance from a ``run.yaml`` file.

        This is a convenience wrapper that combines
        :meth:`RunLog.from_yaml` and :meth:`from_run_log` into a single
        call::

            output = CompiledModel.from_yaml(run_dir / "run.yaml")

        Any additional keyword arguments are forwarded to the constructor
        for fields that are not tracked in the run log (e.g. raw data
        dictionaries).

        :param yaml_path: Path to a ``run.yaml`` metadata file.
        :param extra_kwargs: Extra fields to pass to the constructor.
        :return: An instance of the output class with fields populated.
        :raises FileNotFoundError: If `yaml_path` does not exist.
        :raises ValueError: If the YAML version is unsupported.
        :raises MissingOutputError: If a required field is not in the run
            log.
        """
        run_log = RunLog.from_yaml(yaml_path)
        return cls.from_run_log(run_log, **extra_kwargs)

    @classmethod
    def from_current_run(
        cls, ctx: "HubContext", **extra_kwargs: object
    ) -> Self:
        """Construct an output instance from the current run in the context.

        Any additional keyword arguments are forwarded to the constructor
        for fields that are not tracked in the run log (e.g. raw data
        dictionaries).

        :param ctx: The execution context with an active run.
        :param extra_kwargs: Extra fields to pass to the constructor.
        :return: An instance of the output class with fields populated.
        :raises MissingOutputError: If a required field is not in the run
            log.
        :raises RuntimeError: If no run is currently active.
        """
        if ctx.client is None or ctx.client.current_run_log is None:
            raise RuntimeError("No active run in the context.")
        return cls.from_run_log(ctx.client.current_run_log, **extra_kwargs)

    def download_device_artifacts(
        self,
        ctx: "HubContext",
        device_name: str,
        target_dir: LocalPath | None = None,
    ) -> Self:
        """Download a device's remote artifact directory.

        Uses the ``artifact_dir`` recorded in the device's
        :class:`~embedl_hub._internal.tracking.run_log.DeviceLog` and the
        live runner from ``ctx.devices`` to transfer the remote directory.
        The device connection must still be open (i.e. the caller must
        still be inside ``with ctx:``).

        :param ctx: The active hub context with device connections.
        :param device_name: Name of the device whose artifacts to download.
        :param target_dir: Local directory to download into. If ``None``,
            defaults to ``<artifact_dir>/../remote_<device_name>``.
        :return: A new :class:`ComponentOutput` with the device's
            ``downloaded_artifact_dir`` updated.
        :raises RuntimeError: If the device is not found in the output or
            context, has no remote artifact directory, or has no command
            runner, or if no destination can be determined.
        """
        dev_log = self.devices.get(device_name)
        if dev_log is None:
            raise RuntimeError(f"No device '{device_name}' in this output.")
        if dev_log.artifact_dir is None:
            raise RuntimeError(
                f"Device '{device_name}' has no remote artifact "
                f"directory to download."
            )

        device = ctx.devices.get(device_name)
        if device is None:
            raise RuntimeError(f"Device '{device_name}' not found in context.")
        if device.runner is None:
            raise RuntimeError(
                f"Device '{device_name}' has no command runner."
            )

        if target_dir is None:
            if self.artifact_dir is None:
                raise RuntimeError(
                    "Cannot determine download destination: output has "
                    "no artifact_dir. Provide 'target_dir' explicitly."
                )
            target_dir = self.artifact_dir.parent / f"remote_{device_name}"

        # Imported lazily, as in the original module.
        from embedl_hub._internal.core.device.transfer import get_directory

        get_directory(device.runner, dev_log.artifact_dir, target_dir)

        # The output is frozen: build updated copies instead of mutating.
        new_devices = dict(self.devices)
        new_devices[device_name] = replace(
            dev_log, downloaded_artifact_dir=target_dir
        )
        return replace(self, devices=new_devices)
@dataclass(frozen=True)
class CompiledModel(ComponentOutput):
    """Common output shape shared by every compiler component.

    Fields:

    * **path** — required :class:`LoggedArtifact` for the compiled model.
    * **input** — optional :class:`LoggedArtifact` for the source
      (pre-compilation) model; defaults to ``None`` when it is not tracked.

    Runtime-specific subclasses may declare additional fields (for example
    tensor name mappings). Because ``input`` has a default, fields that
    subclasses add after it must also have defaults (standard dataclass
    ordering rule).
    """

    path: LoggedArtifact
    input: LoggedArtifact | None = None

    @classmethod
    def from_path(
        cls,
        path: LocalPath | RemotePath,
        **extra_kwargs: object,
    ) -> Self:
        """Wrap an existing compiled-model file in a new output instance.

        A minimal :class:`LoggedArtifact` is synthesised for ``path`` (logged
        under the name ``$path``). ``input`` keeps its ``None`` default and
        no ``run_log`` is attached; when the result is handed to another
        component, the component system creates an ``IMPORT`` run
        automatically.

        A local path has its size taken from disk when the file exists
        (otherwise ``0``); a remote path always reports a size of ``0``.

        Extra keyword arguments go straight to the constructor (e.g.
        ``input_name_mapping`` on subclasses).

        :param path: Local or remote location of the compiled model file.
        :param extra_kwargs: Additional constructor arguments.
        :return: A freshly built instance of ``cls``.
        """
        location: LocalPath | RemotePath
        if not isinstance(path, Path):
            location = RemotePath(path)
            size = 0
        else:
            location = LocalPath(path)
            size = 0
            if path.exists():
                size = path.stat().st_size

        path_artifact = LoggedArtifact(
            id="",
            file_name=location.name,
            file_size=size,
            logged_at=datetime.now(UTC),
            file_path=location,
            name="$path",
        )
        return cls(
            artifact_dir=None,
            devices={},
            run_log=None,
            path=path_artifact,
            **extra_kwargs,  # type: ignore[arg-type]
        )