# Copyright (C) 2026 Embedl AB
"""Component output base class for typed access to run log data."""
import types
from dataclasses import dataclass, replace
from datetime import UTC, datetime
from pathlib import Path
from typing import (
TYPE_CHECKING,
ClassVar,
Self,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
from embedl_hub._internal.core.types import LocalPath, RemotePath
from embedl_hub._internal.tracking.run_log import (
DeviceLog,
LoggedArtifact,
LoggedMetric,
LoggedParam,
RunLog,
)
if TYPE_CHECKING:
    from embedl_hub._internal.core.context import HubContext

# Logged entity types that ``ComponentOutput`` subclasses may declare as
# fields (directly, or as ``X | None``); ``from_run_log`` only populates
# fields whose declared type is one of these.
_VALID_FIELD_TYPES = (
    LoggedParam,
    LoggedMetric,
    LoggedArtifact,
)
def _unwrap_optional_logged_type(hint: object) -> type | None:
"""Return the inner ``LoggedX`` type if *hint* is ``LoggedX | None``.
Handles all equivalent spellings:
* ``LoggedParam | None`` / ``None | LoggedParam``
* ``Union[LoggedMetric, None]``
* ``Optional[LoggedArtifact]``
Returns ``None`` when *hint* is not an optional logged type.
"""
origin = get_origin(hint)
if origin is not Union and not isinstance(hint, types.UnionType):
return None
args = get_args(hint)
non_none = [a for a in args if a is not type(None)]
if len(non_none) != 1:
return None
inner = non_none[0]
return inner if inner in _VALID_FIELD_TYPES else None
class MissingOutputError(Exception):
    """Raised when a required output field is not found in the run log.

    :param field_name: Name of the output field that could not be
        resolved.
    :param field_type: The logged type the field was declared with
        (e.g. :class:`LoggedMetric`).
    """

    def __init__(self, field_name: str, field_type: type) -> None:
        self.field_name = field_name
        self.field_type = field_type
        # Fall back to repr() for objects without ``__name__`` so that
        # formatting the message can never itself raise.
        type_name = getattr(field_type, "__name__", repr(field_type))
        super().__init__(
            f"Required output '{field_name}' of type {type_name} not found in run log."
        )

    def __reduce__(self) -> tuple[object, ...]:
        # ``Exception.__reduce__`` re-creates the instance from ``self.args``
        # (just the formatted message), which does not match this
        # two-argument ``__init__``: pickling or copying would fail with a
        # TypeError. Rebuild from the original constructor arguments and
        # carry over any extra instance state (e.g. ``__notes__``).
        return (type(self), (self.field_name, self.field_type), self.__dict__)
@dataclass(frozen=True)
class ComponentOutput:
    """Base class for component outputs with typed access to run data.

    Subclasses are frozen dataclasses whose fields are annotated as
    :class:`LoggedParam`, :class:`LoggedMetric`, or
    :class:`LoggedArtifact` (or ``LoggedX | None`` for optional
    fields). They can be constructed manually or from a
    :class:`RunLog` via :meth:`from_run_log`.

    Example::

        @dataclass(frozen=True)
        class MyOutput(ComponentOutput):
            accuracy: LoggedMetric
            learning_rate: LoggedParam
            model: LoggedArtifact

        # Construct manually:
        output = MyOutput(
            accuracy=LoggedMetric(...),
            learning_rate=LoggedParam(...),
            model=LoggedArtifact(...),
        )

        # Or construct from a run log:
        output = MyOutput.from_run_log(run_log)
        print(output.accuracy.value)

    If a required field is defined in the class but no matching logged
    entity with that name exists in the :class:`RunLog`, a
    :class:`MissingOutputError` is raised; optional fields are set to
    ``None`` instead.

    The :attr:`run_log` field holds a reference to the :class:`RunLog`
    the output was created from, or ``None`` if the output was
    constructed manually.
    """

    artifact_dir: LocalPath | None
    devices: dict[str, DeviceLog]
    run_log: RunLog | None

    # Registry of run-log-backed fields (field name -> LoggedParam /
    # LoggedMetric / LoggedArtifact), rebuilt for every subclass by
    # ``__init_subclass__``. Empty on this base class so that
    # ``from_run_log`` also works when called on it directly.
    _output_fields: ClassVar[dict[str, type]] = {}
    # Names from ``_output_fields`` that may be absent from the run log
    # (declared as ``LoggedX | None``).
    _optional_output_fields: ClassVar[set[str]] = set()

    def __init_subclass__(cls, **kwargs) -> None:
        """Record which fields of a subclass are populated from a run log.

        Invoked when the subclass body is created, i.e. before the
        subclass's own ``@dataclass`` decorator runs. The registries of
        all base classes are merged first, then the subclass's own
        annotations are inspected:

        * ``LoggedParam`` / ``LoggedMetric`` / ``LoggedArtifact`` fields
          are tracked as required;
        * ``LoggedX | None`` fields are tracked as optional;
        * any other annotation (e.g. raw data dicts) is not populated by
          :meth:`from_run_log`, and stops tracking an inherited field of
          the same name.

        Private names and the base fields ``artifact_dir``, ``devices``
        and ``run_log`` are ignored.
        """
        super().__init_subclass__(**kwargs)

        tracked: dict[str, type] = {}
        optional: set[str] = set()

        # Fold base registries from the most generic class to the most
        # derived one, so a redeclaration in a nearer base wins over a
        # farther one - including its required/optional status. Only each
        # class's *own* registry is read (``vars``); every subclass stores
        # a complete copy of what it inherited.
        for base in reversed(cls.__mro__[1:]):
            namespace = vars(base)
            base_fields = namespace.get("_output_fields")
            if not base_fields:
                continue
            base_optional = namespace.get("_optional_output_fields", set())
            for name, field_type in base_fields.items():
                tracked[name] = field_type
                if name in base_optional:
                    optional.add(name)
                else:
                    optional.discard(name)

        own_annotations = getattr(cls, "__annotations__", None)
        if own_annotations:
            hints = get_type_hints(cls)
            for name in own_annotations:
                if (
                    name not in hints
                    or name.startswith("_")
                    or name in ("artifact_dir", "devices", "run_log")
                ):
                    continue
                hint = hints[name]
                if hint in _VALID_FIELD_TYPES:
                    # Plain ``LoggedX``: required, even if a base made it
                    # optional.
                    tracked[name] = hint
                    optional.discard(name)
                    continue
                inner = _unwrap_optional_logged_type(hint)
                if inner is not None:
                    tracked[name] = inner
                    optional.add(name)
                else:
                    # Redeclared as a non-logged field: no longer filled
                    # from the run log.
                    tracked.pop(name, None)
                    optional.discard(name)

        cls._output_fields = tracked
        cls._optional_output_fields = optional

    @classmethod
    def from_run_log(cls, run_log: RunLog, **extra_kwargs: object) -> Self:
        """Construct an output instance from a :class:`RunLog`.

        Matches fields by name from the :class:`RunLog`: params, metrics
        and named artifacts are each looked up for fields of the
        corresponding logged type. Names logged with a ``$`` prefix are
        matched to fields without the prefix (e.g. ``$path`` → ``path``).

        Fields typed as ``LoggedX | None`` (or equivalently
        ``Optional[LoggedX]``) are *optional*: if the corresponding
        entry is missing from the run log the field is set to ``None``
        instead of raising.

        Any additional keyword arguments are forwarded to the
        constructor, intended for fields that are not tracked in the run
        log (e.g. raw data dictionaries). They are applied last, so they
        take precedence over values derived from the run log.

        :param run_log: The :class:`RunLog` to extract data from.
        :param extra_kwargs: Extra fields to pass to the constructor.
        :return: An instance of the output class with fields populated.
        :raises MissingOutputError: If a required field is not in the
            run log.
        :raises ValueError: If both prefixed and unprefixed versions
            exist.
        """
        # Each logged type lives in its own namespace of the run log.
        sources = {
            LoggedParam: run_log.params,
            LoggedMetric: run_log.metrics,
            LoggedArtifact: run_log.named_artifacts,
        }

        T = TypeVar("T")

        def _resolve_name(field_name: str, source: dict[str, T]) -> str | None:
            """Return the key for *field_name* in *source*, or ``None``.

            Prefers the ``$``-prefixed key; raises if both spellings exist.
            """
            prefixed = f"${field_name}"
            has_plain = field_name in source
            has_prefixed = prefixed in source
            if has_plain and has_prefixed:
                raise ValueError(
                    f"Ambiguous field '{field_name}': both '{field_name}' and "
                    f"'{prefixed}' exist in the run log."
                )
            if has_prefixed:
                return prefixed
            if has_plain:
                return field_name
            return None

        kwargs: dict[str, object] = {
            "artifact_dir": run_log.artifact_dir,
            "devices": run_log.devices,
            "run_log": run_log,
        }

        optional_fields = cls._optional_output_fields
        for field_name, field_type in cls._output_fields.items():
            source = sources.get(field_type)
            if source is None:
                # Not a logged type; nothing to look up.
                continue
            resolved = _resolve_name(field_name, source)
            if resolved is not None:
                kwargs[field_name] = source[resolved]
            elif field_name in optional_fields:
                kwargs[field_name] = None
            else:
                raise MissingOutputError(field_name, field_type)

        kwargs.update(extra_kwargs)
        return cls(**kwargs)  # type: ignore[arg-type]

    @classmethod
    def from_yaml(cls, yaml_path: LocalPath, **extra_kwargs: object) -> Self:
        """Construct an output instance from a ``run.yaml`` file.

        This is a convenience wrapper that combines
        :meth:`RunLog.from_yaml` and :meth:`from_run_log` into a
        single call::

            output = CompiledModel.from_yaml(run_dir / "run.yaml")

        Any additional keyword arguments are forwarded to the
        constructor for fields that are not tracked in the run log
        (e.g. raw data dictionaries).

        :param yaml_path: Path to a ``run.yaml`` metadata file.
        :param extra_kwargs: Extra fields to pass to the constructor.
        :return: An instance of the output class with fields populated.
        :raises FileNotFoundError: If ``yaml_path`` does not exist.
        :raises ValueError: If the YAML version is unsupported.
        :raises MissingOutputError: If a required field is not in the
            run log.
        """
        run_log = RunLog.from_yaml(yaml_path)
        return cls.from_run_log(run_log, **extra_kwargs)

    @classmethod
    def from_current_run(
        cls, ctx: "HubContext", **extra_kwargs: object
    ) -> Self:
        """Construct an output instance from the current run in the context.

        Reads ``ctx.client.current_run_log`` and delegates to
        :meth:`from_run_log`. Any additional keyword arguments are
        forwarded to the constructor for fields that are not tracked in
        the run log (e.g. raw data dictionaries).

        :param ctx: The execution context with an active run.
        :param extra_kwargs: Extra fields to pass to the constructor.
        :return: An instance of the output class with fields populated.
        :raises MissingOutputError: If a required field is not in the run log.
        :raises RuntimeError: If no run is currently active.
        """
        if ctx.client is None or ctx.client.current_run_log is None:
            raise RuntimeError("No active run in the context.")
        return cls.from_run_log(ctx.client.current_run_log, **extra_kwargs)

    def download_device_artifacts(
        self,
        ctx: "HubContext",
        device_name: str,
        target_dir: LocalPath | None = None,
    ) -> Self:
        """Download a device's remote artifact directory.

        Uses the ``artifact_dir`` recorded in the device's
        :class:`~embedl_hub._internal.tracking.run_log.DeviceLog` and the live
        runner from ``ctx.devices`` to transfer the remote directory.
        The device connection must still be open (i.e. the caller must
        still be inside ``with ctx:``).

        :param ctx: The active hub context with device connections.
        :param device_name: Name of the device whose artifacts to
            download.
        :param target_dir: Local directory to download into. If
            ``None``, defaults to
            ``<artifact_dir>/../remote_<device_name>``.
        :return: A new instance of the same class whose ``devices`` entry
            for *device_name* has ``downloaded_artifact_dir`` set to
            *target_dir*; ``self`` is left unchanged.
        :raises RuntimeError: If the device is not found in the output
            or context, has no remote artifact directory, or has no
            command runner, or if no destination can be determined.
        """
        dev_log = self.devices.get(device_name)
        if dev_log is None:
            raise RuntimeError(f"No device '{device_name}' in this output.")
        if dev_log.artifact_dir is None:
            raise RuntimeError(
                f"Device '{device_name}' has no remote artifact "
                f"directory to download."
            )

        device = ctx.devices.get(device_name)
        if device is None:
            raise RuntimeError(f"Device '{device_name}' not found in context.")
        if device.runner is None:
            raise RuntimeError(
                f"Device '{device_name}' has no command runner."
            )

        if target_dir is None:
            if self.artifact_dir is None:
                raise RuntimeError(
                    "Cannot determine download destination: output has "
                    "no artifact_dir. Provide 'target_dir' explicitly."
                )
            target_dir = self.artifact_dir.parent / f"remote_{device_name}"

        # Imported lazily, as in the original module layout.
        from embedl_hub._internal.core.device.transfer import get_directory

        get_directory(device.runner, dev_log.artifact_dir, target_dir)

        # Frozen dataclass: build updated copies instead of mutating.
        new_devices = dict(self.devices)
        new_devices[device_name] = replace(
            dev_log, downloaded_artifact_dir=target_dir
        )
        return replace(self, devices=new_devices)
@dataclass(frozen=True)
class CompiledModel(ComponentOutput):
    """Base class for compiled model outputs.

    Every compiler component produces an output that carries at least:

    * **path** — a :class:`LoggedArtifact` pointing at the compiled
      model.

    Optionally:

    * **input** — a :class:`LoggedArtifact` pointing at the original
      (pre-compilation) model, or ``None`` when the source model is
      not tracked.

    Runtime-specific subclasses may add extra fields (e.g. tensor name
    mappings).
    """

    path: LoggedArtifact
    input: LoggedArtifact | None = None

    @classmethod
    def from_path(
        cls,
        path: LocalPath | RemotePath,
        **extra_kwargs: object,
    ) -> Self:
        """Create a compiled model output from a model file path.

        Builds a minimal :class:`LoggedArtifact` (empty ``id``, logged
        under the name ``$path``) for the ``path`` field. No
        ``run_log`` is attached and ``input`` stays ``None`` unless it
        is supplied through *extra_kwargs*. The component system will
        create an ``IMPORT`` run automatically when this output is
        passed to another component.

        A :class:`~pathlib.Path` argument is treated as a local path,
        and its size is read from disk on a best-effort basis (``0`` if
        it cannot be stat'ed). Anything else is treated as a remote
        path with size ``0``.

        Any additional keyword arguments are forwarded to the
        constructor (e.g. ``input_name_mapping`` on subclasses).

        :param path: Local or remote path to the compiled model file.
        :param extra_kwargs: Extra fields to pass to the constructor.
        :return: A new instance of this class.
        """
        file_path: LocalPath | RemotePath
        if isinstance(path, Path):
            file_path = LocalPath(path)
            # EAFP: a single stat() avoids the race between exists() and
            # stat(); the size is only metadata, so failure records 0.
            try:
                file_size = path.stat().st_size
            except (OSError, ValueError):
                file_size = 0
        else:
            file_path = RemotePath(path)
            file_size = 0

        artifact = LoggedArtifact(
            id="",
            file_name=file_path.name,
            file_size=file_size,
            logged_at=datetime.now(UTC),
            file_path=file_path,
            name="$path",
        )
        return cls(
            artifact_dir=None,
            devices={},
            run_log=None,
            path=artifact,
            **extra_kwargs,  # type: ignore[arg-type]
        )