# Source code for embedl_hub._internal.core.component.abc

# Copyright (C) 2026 Embedl AB

"""Abstract base classes for component infrastructure.

This module provides the base classes and decorators for defining components
that can be dispatched to different providers based on the execution context.

Components define their interface (parameters and return type) via
``run``.  Provider implementations are registered externally using
the :meth:`Component.provider` class-method decorator::

    class MyCompiler(Component):
        run_type = RunType.COMPILE

        def __init__(
            self,
            *,
            name: str | None = None,
            device: str | None = None,
            optimize: bool = True,
        ) -> None:
            \"\"\"Create a new compiler.\"\"\"
            super().__init__(name=name, device=device, optimize=optimize)

        def run(
            self,
            ctx: HubContext,
            model_path: Path,
            *,
            device: str | None = None,
            optimize: bool = True,
        ) -> MyOutput:
            \"\"\"Compile a model.\"\"\"
            ...

    @MyCompiler.provider(ProviderType.EMBEDL_ONNXRUNTIME)
    def _compile_cli(
        ctx: HubContext,
        model_path: Path,
        *,
        device: str | None = None,
        optimize: bool = True,
    ) -> MyOutput:
        # *optimize* is already resolved: call-time value > constructor default
        ...

Each component must define ``__init__`` with the same keyword-only
parameters as ``run``, plus ``name: str | None = None``, and call
``super().__init__(...)`` to store them as defaults.  ``run`` must declare
exactly one of ``device`` or ``devices`` as a keyword-only parameter.  The
metaclass validates these rules and that the parameter names match::

    compiler = MyCompiler(optimize=False)

stores ``optimize=False`` as the default, which will be used whenever
the caller does not pass ``optimize`` explicitly at call time.
"""

from __future__ import annotations

import inspect
from abc import ABC, ABCMeta
from collections.abc import Callable
from typing import ClassVar, get_type_hints

from embedl_hub._internal.core.component.dispatch import (
    create_dispatcher,
    validate_provider_signature,
)
from embedl_hub._internal.core.context import HubContext
from embedl_hub._internal.tracking.rest_api import RunType


class NoProviderError(NotImplementedError):
    """Signal that the active device type has no registered provider."""
# ---------------------------------------------------------------------------
# Metaclass
# ---------------------------------------------------------------------------


class ComponentMeta(ABCMeta):
    """Metaclass that checks a component declaration and installs dispatch.

    For every class that defines ``run`` in its own body, the metaclass
    verifies the declaration (``run_type``, the ``ctx`` annotation, the
    ``device``/``devices`` keyword and the ``__init__`` keyword names),
    records the parameter layout on the class, and replaces ``run`` with
    the callable returned by ``create_dispatcher``.
    """

    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, object],
        **kwargs: object,
    ) -> ComponentMeta:
        """Create the class, validating and wrapping ``run`` when present.

        :param name: Name of the class being created.
        :param bases: Base classes of the class being created.
        :param namespace: Class body namespace.
        :param kwargs: Extra class keyword arguments, forwarded unchanged.
        :return: The newly created class.
        :raises TypeError: If the class defines ``run`` but violates one of
            the declaration rules checked below.
        """
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)

        # The root ``Component`` class (no base carries a registry yet) only
        # receives empty bookkeeping attributes; nothing is validated.
        if name == "Component":
            if not any(hasattr(base, "_providers") for base in bases):
                cls._providers = {}  # type: ignore[attr-defined]
                cls._keyword_params = {}  # type: ignore[attr-defined]
                cls._positional_params = []  # type: ignore[attr-defined]
                return cls

        # ---- locate run (the interface) --------------------------------
        run_impl = namespace.get("run")
        if not callable(run_impl):
            # Nothing to wrap here: the class body has no usable ``run``
            # (for example a mixin or an intermediate class).
            return cls

        # ---- validate run_type ------------------------------------------
        if "run_type" not in namespace:
            raise TypeError(
                f"'{name}' must define a 'run_type' class variable "
                f"(e.g. run_type = RunType.COMPILE)."
            )
        declared_run_type = namespace["run_type"]
        if not isinstance(declared_run_type, RunType):
            raise TypeError(
                f"'{name}.run_type' must be a RunType enum member, "
                f"got {type(declared_run_type).__name__}."
            )

        # ---- extract run signature --------------------------------------
        run_params = list(inspect.signature(run_impl).parameters.values())
        if len(run_params) < 2:
            raise TypeError(
                f"'{name}.run' must have at least 'self' and "
                "'ctx: HubContext' parameters."
            )

        # ---- validate the ctx annotation --------------------------------
        try:
            resolved_hints = get_type_hints(run_impl)
        except Exception:
            # Unresolvable annotations: fall back to the raw annotation.
            resolved_hints = {}

        ctx_param = run_params[1]  # index 0 is ``self``
        ctx_annotation = resolved_hints.get(
            ctx_param.name, ctx_param.annotation
        )
        # Either the class object itself or its bare name is accepted; the
        # name shows up when annotations are kept as strings and could not
        # be resolved above.
        annotated_as_context = (
            ctx_annotation is HubContext
            or ctx_annotation == HubContext.__name__
        )
        if not annotated_as_context:
            raise TypeError(
                f"'{name}.run' first parameter after 'self' must be "
                f"annotated as HubContext, got '{ctx_param.annotation}'."
            )

        # ---- split the remaining parameters (after self, ctx) -----------
        interface_params = run_params[2:]
        keyword_params: dict[str, object] = {
            param.name: param.default
            for param in interface_params
            if param.kind == inspect.Parameter.KEYWORD_ONLY
        }
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional_params: list[str] = [
            param.name
            for param in interface_params
            if param.kind in positional_kinds
        ]

        # ---- validate device / devices kwarg ----------------------------
        declared_device_kwargs = [
            key for key in ("device", "devices") if key in keyword_params
        ]
        if len(declared_device_kwargs) == 2:
            raise TypeError(
                f"'{name}' declares both 'device' and 'devices' in "
                f"'run'. Use exactly one: 'device: str | None = None' "
                f"for single-device components, or "
                f"'devices: list[str] | None = None' for multi-device "
                f"components."
            )
        if not declared_device_kwargs:
            raise TypeError(
                f"'{name}' must declare either 'device' or 'devices' "
                f"as a keyword-only parameter in 'run' (e.g. "
                f"'device: str | None = None' or "
                f"'devices: list[str] | None = None')."
            )

        # ---- validate __init__ ------------------------------------------
        if "__init__" not in namespace:
            raise TypeError(
                f"'{name}' must define '__init__' with the same "
                "keyword-only parameters as 'run' (plus "
                "'name: str | None = None'). Call "
                "'super().__init__(...)' to store them as defaults."
            )
        init_signature = inspect.signature(namespace["__init__"])
        init_keyword_names = {
            param.name
            for param in list(init_signature.parameters.values())[1:]
            if param.kind == inspect.Parameter.KEYWORD_ONLY
        }

        # ``name`` is mandatory on __init__ ...
        if "name" not in init_keyword_names:
            raise TypeError(
                f"'{name}.__init__' must declare "
                "'name: str | None = None' as a keyword-only "
                "parameter."
            )

        # ... and is left out when comparing against the keywords of run.
        init_names = init_keyword_names - {"name"}
        run_names = set(keyword_params)
        if init_names != run_names:
            differences = (
                ("missing", run_names - init_names),
                ("extra", init_names - run_names),
            )
            parts: list[str] = [
                f"{label}: {sorted(names)}"
                for label, names in differences
                if names
            ]
            raise TypeError(
                f"'{name}.__init__' keyword parameters don't match "
                f"'run': {', '.join(parts)}."
            )

        # ---- record the interface on the class --------------------------
        cls._providers = {}  # type: ignore[attr-defined]
        cls._keyword_params = keyword_params  # type: ignore[attr-defined]
        cls._positional_params = positional_params  # type: ignore[attr-defined]
        cls._call_interface = run_impl  # type: ignore[attr-defined]

        # ---- swap run for the dispatcher --------------------------------
        cls.run = create_dispatcher(  # type: ignore[attr-defined]
            name, run_impl, positional_params, keyword_params
        )

        # ``run`` is no longer abstract once the dispatcher is installed.
        abstract_names = getattr(cls, "__abstractmethods__", None)
        if abstract_names is not None and "run" in abstract_names:
            cls.__abstractmethods__ = abstract_names - {"run"}

        return cls


# ---------------------------------------------------------------------------
# Base Component
# ---------------------------------------------------------------------------
class Component(ABC, metaclass=ComponentMeta):
    """Abstract base class for all components.

    Subclasses define their interface by implementing ``run`` with
    the desired positional and keyword-only parameters.  The body
    should just be ``...`` — the actual work is performed by
    *device-specific implementations* registered via :meth:`provider`.

    Each subclass must define a ``run_type`` class variable indicating
    which :class:`~embedl_hub._internal.tracking.RunType` to use when
    logging runs.

    ``run`` must take ``ctx: HubContext`` as its first parameter after
    ``self`` and must declare exactly one of ``device`` or ``devices`` as
    a keyword-only parameter.

    Each subclass must also define ``__init__`` with the same keyword-only
    parameters as ``run`` plus ``name: str | None = None``, calling
    ``super().__init__(...)`` to store them as defaults.  The metaclass
    validates that the parameter names match.

    Values passed to the constructor serve as defaults that can be
    overridden at each call site.

    Example::

        class MyCompiler(Component):
            run_type = RunType.COMPILE

            def __init__(
                self,
                *,
                name: str | None = None,
                device: str | None = None,
                optimize: bool = True,
            ) -> None:
                \"\"\"Create a new compiler.\"\"\"
                super().__init__(name=name, device=device, optimize=optimize)

            def run(
                self,
                ctx: HubContext,
                model: Path,
                *,
                device: str | None = None,
                optimize: bool = True,
            ) -> MyOutput:
                \"\"\"Compile *model*.\"\"\"
                ...

        @MyCompiler.provider(ProviderType.EMBEDL_ONNXRUNTIME)
        def _compile_cli(
            ctx: HubContext,
            model: Path,
            *,
            device: str | None = None,
            optimize: bool = True,
        ) -> MyOutput:
            # *optimize* is already resolved
            ...

        compiler = MyCompiler(optimize=False)  # default optimize=False
        compiler.run(ctx, some_model)  # uses optimize=False
        compiler.run(ctx, some_model, optimize=True)  # overrides to True
    """

    # Which RunType a concrete component logs its runs under.
    run_type: ClassVar[RunType]
    # Registered provider callables, keyed by provider type.
    _providers: ClassVar[dict[str, Callable]]
    # Keyword-only parameters of ``run`` mapped to their declared defaults.
    _keyword_params: ClassVar[dict[str, object]]
    # Names of the positional parameters of ``run`` (after ``self``, ``ctx``).
    _positional_params: ClassVar[list[str]]

    @classmethod
    def providers(cls) -> list[str]:
        """Return the registered provider type names.

        :return: A list of provider-type strings.
        """
        return list(cls._providers)

    # Default-storing __init__ ------------------------------------------

    def __init__(self, **kwargs: object) -> None:
        """Store keyword argument values as instance attributes.

        Subclasses call ``super().__init__(...)`` from their own
        ``__init__`` — parameter validation is handled by the
        subclass signature (at call time) and the metaclass
        (at class creation time).

        The ``name`` keyword is handled specially: it is stored as
        a read-only :attr:`name` property and defaults to the class
        name when ``None``.

        :param kwargs: Constructor defaults; every entry except ``name``
            is set as an instance attribute of the same name.
        """
        raw_name = kwargs.pop("name", None)
        self._name: str = (
            raw_name if raw_name is not None else type(self).__name__
        )
        for name, value in kwargs.items():
            setattr(self, name, value)

    @property
    def name(self) -> str:
        """The display name of this component instance.

        Defaults to the class name when not provided at construction
        time.  Used as the run name in the tracking system.
        """
        return self._name

    # Provider registration --------------------------------------------

    @classmethod
    def provider(cls, provider_type: str) -> Callable:
        """Decorator that registers a function as a provider.

        Usage::

            @MyComponent.provider(ProviderType.EMBEDL_ONNXRUNTIME)
            def _run_cli(ctx: HubContext, ...) -> OutputType:
                ...

        The decorated function must have the same positional and
        keyword-only parameters as the component's ``run``
        (excluding ``self``).  Keyword argument values will be
        pre-resolved by the component system (call-time overrides
        beat constructor defaults).

        :param provider_type: A :class:`ProviderType` member or
            equivalent string (e.g. ``ProviderType.EMBEDL_ONNXRUNTIME``).
        :return: A decorator that registers the function and returns
            it unchanged.
        :raises ValueError: From the returned decorator, if
            *provider_type* is already registered on this class.
        """

        def decorator(fn: Callable) -> Callable:
            # Reject duplicates before validating, so an existing
            # registration is never overwritten.
            if provider_type in cls._providers:
                raise ValueError(
                    f"Duplicate provider type: '{provider_type}' on "
                    f"'{cls.__name__}'."
                )
            validate_provider_signature(cls, provider_type, fn)
            cls._providers[provider_type] = fn
            return fn

        return decorator