Source code for embedl_hub._internal.core.device.ssh

# Copyright (C) 2026 Embedl AB

"""SSH-based command runner for remote device execution.

This module provides an SSH command runner implementation using asyncssh
for executing commands and transferring files on remote devices.
"""

import asyncio
import codecs
import io
import warnings
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Self, cast

import asyncssh
from rich.console import Console
from rich.text import Text

from embedl_hub._internal.core.device.abc import (
    CommandResult,
    CommandRunner,
    Connectable,
)
from embedl_hub._internal.core.types import LocalPath, RemotePath


class SSHConnectionError(Exception):
    """Signal that an SSH operation cannot proceed without a live connection.

    Raised when a runner method is used before a connection has been
    established, and when establishing the connection fails because the
    remote host is unreachable.
    """
@dataclass
class SSHConfig:
    """Settings describing how to reach a remote device over SSH.

    :param host: Hostname or IP address of the remote device.
    :param username: Account name used to log in.
    :param port: TCP port of the SSH server (default: 22).
    :param password: Optional login password. Password authentication is
        discouraged for security reasons and causes a warning to be issued;
        prefer key-based authentication instead.
    :param private_key_path: Optional path to a private key file. When it is
        omitted, asyncssh tries whatever authentication methods it can find.
    :param known_hosts: Path to a known_hosts file, or ``None`` (the default)
        to disable host key checking.
    """

    # Mandatory connection target (positional fields, in this order).
    host: str
    username: str
    # Optional settings with defaults.
    port: int = 22
    password: str | None = None
    private_key_path: LocalPath | None = None
    known_hosts: LocalPath | None = None
_console = Console(stderr=True)
"""Shared Rich console for SSH output (writes to stderr so it doesn't mix
with real program stdout when piped)."""

_BORDER_STYLE = "dim cyan"
_INDENT = " "
# Left gutter printed before every line inside a box. It reuses the border
# style so the whole frame renders in a single colour.
_PREFIX = Text(f"{_INDENT}│ ", style=_BORDER_STYLE)


def _box_top(console: Console, title: str | None = None) -> None:
    """Print the top edge of a box: `` ┌─── title ───…``.

    :param console: Console to print to; its width sets the edge length.
    :param title: Optional label centred within the edge.
    """
    # Leave room for the indent and the corner glyph.
    width = console.width - len(_INDENT) - 1
    if title:
        label = f" {title} "
        left = max((width - len(label)) // 2, 1)
        right = max(width - left - len(label), 1)
        line = f"{_INDENT}┌{'─' * left}{label}{'─' * right}"
    else:
        line = f"{_INDENT}┌{'─' * width}"
    console.print(Text(line, style=_BORDER_STYLE), highlight=False)


def _box_bottom(console: Console) -> None:
    """Print the bottom edge of a box: `` └───…``.

    :param console: Console to print to; its width sets the edge length.
    """
    # Leave room for the indent and the corner glyph.
    width = console.width - len(_INDENT) - 1
    console.print(
        Text(f"{_INDENT}└{'─' * width}", style=_BORDER_STYLE),
        highlight=False,
    )


class _TeeWriter(io.RawIOBase):
    """Captures bytes *and* prints each line to the console with a prefix.

    asyncssh closes file objects passed as stdout/stderr, so this wrapper
    prevents the real stream from being closed while also accumulating all
    written bytes so they can be retrieved afterwards.
    """

    def __init__(self, console: Console, style: str = "") -> None:
        """Create a writer that echoes to *console* using Rich *style*."""
        super().__init__()
        self._console = console
        self._style = style
        self._parts: list[bytes] = []
        self._line_buf = ""
        # Incremental decoding keeps a multi-byte UTF-8 character that is
        # split across two write() calls intact instead of turning it into
        # replacement characters.
        self._decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def writable(self) -> bool:
        """Report that this stream accepts writes."""
        return True

    def write(self, b: bytes | bytearray) -> int:  # type: ignore[override]
        """Record *b* and print every newly completed line.

        :return: The number of bytes accepted (always all of them).
        """
        data = bytes(b)
        self._parts.append(data)
        text = self._line_buf + self._decoder.decode(data)
        lines = text.split("\n")
        # Last element is the incomplete line (may be empty)
        self._line_buf = lines.pop()
        for line in lines:
            self._print_line(line)
        return len(data)

    def _print_line(self, line: str) -> None:
        """Print one line of output behind the box gutter."""
        self._console.print(
            _PREFIX + Text(line, style=self._style),
            highlight=False,
        )

    def flush_remaining(self) -> None:
        """Print any remaining partial line in the buffer."""
        # Drain bytes still held by the decoder; an incomplete trailing
        # sequence is emitted as a replacement character.
        self._line_buf += self._decoder.decode(b"", final=True)
        self._decoder.reset()
        if self._line_buf:
            self._print_line(self._line_buf)
            self._line_buf = ""

    @property
    def captured(self) -> str:
        """Return all captured output decoded as UTF-8."""
        return b"".join(self._parts).decode("utf-8", errors="replace")

    def close(self) -> None:
        # Intentionally do NOT close the underlying stream.
        pass
class SSHCommandRunner(CommandRunner, Connectable):
    """SSH-based command runner for remote device execution.

    Implements the :class:`CommandRunner` and :class:`Connectable` protocols
    using asyncssh for SSH connections. Call :meth:`connect` to obtain a
    context manager that establishes the connection.

    Example::

        runner = SSHCommandRunner(config)
        with runner.connect():
            runner.run(["ls", "-la"])
    """

    device_name: str | None
    """Human-readable device name, set by the owning
    :class:`~embedl_hub._internal.core.device.abc.Device`."""

    def __init__(self, config: SSHConfig) -> None:
        """Initialize the SSH command runner.

        :param config: SSH connection configuration.
        """
        if config.password is not None:
            warnings.warn(
                "Using password authentication is discouraged for security reasons. "
                "Consider using key-based authentication instead.",
                UserWarning,
                stacklevel=2,
            )
        self._config = config
        self._connection: asyncssh.SSHClientConnection | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self.device_name: str | None = None

    @property
    def is_active(self) -> bool:
        """Whether an SSH connection is currently established."""
        return self._connection is not None

    def _target_description(self) -> str:
        """Describe the connection target for user-facing error messages."""
        address = (
            f"{self._config.username}@{self._config.host}:{self._config.port}"
        )
        if self.device_name:
            return f"device '{self.device_name}' ({address})"
        return address

    async def _connect(self) -> asyncssh.SSHClientConnection:
        """Establish an SSH connection to the remote device.

        :return: An active SSH client connection.
        """
        connect_kwargs: dict[str, object] = {
            "host": self._config.host,
            "port": self._config.port,
            "username": self._config.username,
        }
        if self._config.password is not None:
            connect_kwargs["password"] = self._config.password
        if self._config.private_key_path is not None:
            connect_kwargs["client_keys"] = [str(self._config.private_key_path)]
        # ``None`` disables host key checking in asyncssh.
        if self._config.known_hosts is None:
            connect_kwargs["known_hosts"] = None
        else:
            connect_kwargs["known_hosts"] = str(self._config.known_hosts)
        return await asyncssh.connect(**connect_kwargs)

    def _ensure_connected(self) -> asyncssh.SSHClientConnection:
        """Ensure a connection is established and return it.

        :return: The active SSH connection.
        :raises SSHConnectionError: If no connection is established.
        """
        if self._connection is None:
            raise SSHConnectionError(
                "No SSH connection established. "
                "Call runner.connect() first: `with runner.connect(): ...`"
            )
        return self._connection

    def _ensure_loop(self) -> asyncio.AbstractEventLoop:
        """Return the active event loop.

        :return: The event loop created during :meth:`connect`.
        :raises SSHConnectionError: If no connection (and thus no loop) exists.
        """
        if self._loop is None:
            raise SSHConnectionError(
                "No SSH connection established. "
                "Call runner.connect() first: `with runner.connect(): ...`"
            )
        return self._loop

    def _close_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Shut down async generators on *loop*, close it and forget it."""
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
            self._loop = None

    @contextmanager
    def connect(self) -> Generator[Self, None, None]:
        """Establish an SSH connection for the duration of the context.

        A private event loop is created for the connection and closed again
        when the context exits, or when connecting fails for any reason.

        :yields: This runner instance with an active connection.
        :raises RuntimeError: If a connection is already active.
        :raises SSHConnectionError: If the remote host cannot be reached
            (an :class:`OSError` while connecting). Other errors raised by
            asyncssh (e.g. authentication failures) propagate unchanged.
        """
        if self._connection is not None:
            raise RuntimeError(
                "SSHCommandRunner already has an active connection. "
                "Cannot open a new connection while one is active."
            )
        loop = asyncio.new_event_loop()
        self._loop = loop
        try:
            self._connection = loop.run_until_complete(self._connect())
        except OSError as exc:
            self._close_loop(loop)
            raise SSHConnectionError(
                f"Could not connect to {self._target_description()}. "
                f"Check that the device is reachable and the SSH server is running."
            ) from exc
        except BaseException:
            # Non-OSError failures (asyncssh auth/host-key errors,
            # KeyboardInterrupt, ...) must not leak the loop either.
            self._close_loop(loop)
            raise
        try:
            yield self
        finally:
            try:
                connection = self._connection
                if connection is not None:
                    connection.close()
                    loop.run_until_complete(connection.wait_closed())
            finally:
                # Always reset state and release the loop, even if closing
                # the connection itself failed.
                self._connection = None
                self._close_loop(loop)

    def _print_header(self, console: Console, addr: str) -> None:
        """Print the blank separator line and 'Running command on …' header."""
        name = self.device_name
        header = (
            f"Running command on '{name}' ({addr}):"
            if name is not None
            else f"Running command on {addr}:"
        )
        console.print()
        console.print(header, highlight=False)

    async def _run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Run *command* remotely, echoing it (and its output) to the console.

        Arguments are joined with single spaces without shell quoting, so
        the remote shell interprets them as written.
        """
        conn = self._ensure_connected()
        cmd_str = " ".join(map(str, command))
        if cwd is not None:
            cmd_str = f"cd {cwd} && {cmd_str}"

        name = self.device_name
        addr = f"{self._config.username}@{self._config.host}"
        console = _console
        self._print_header(console, addr)

        if hide:
            # Compact one-liner: no box, no streaming output
            console.print(
                Text(f"{_INDENT}└ ", style=_BORDER_STYLE)
                + Text("$ ", style="bold green")
                + Text(cmd_str, style="bold"),
                highlight=False,
            )
            result = await conn.run(cmd_str, check=True)
            return cast(CommandResult, result)

        # Box with streamed output
        _box_top(console, title=name or addr)
        console.print(
            _PREFIX + Text("$ ", style="bold green") + Text(cmd_str, style="bold"),
            highlight=False,
        )
        stdout_tee = _TeeWriter(console)
        stderr_tee = _TeeWriter(console, style="red")
        try:
            result = await conn.run(
                cmd_str,
                check=True,
                stdout=stdout_tee,
                stderr=stderr_tee,
            )
        except asyncssh.ProcessError as exc:
            # Output was redirected into the tees, so the exception itself
            # carries none; attach what was captured for callers to inspect.
            exc.stdout = stdout_tee.captured
            exc.stderr = stderr_tee.captured
            raise
        finally:
            stdout_tee.flush_remaining()
            stderr_tee.flush_remaining()
            _box_bottom(console)
        result.stdout = stdout_tee.captured
        result.stderr = stderr_tee.captured
        return cast(CommandResult, result)

    async def _get(self, source: RemotePath, destination: LocalPath) -> None:
        """Copy *source* from the device to *destination* via SCP."""
        conn = self._ensure_connected()
        await asyncssh.scp((conn, str(source)), str(destination))

    async def _put(self, source: LocalPath, destination: RemotePath) -> None:
        """Copy local *source* to *destination* on the device via SCP."""
        conn = self._ensure_connected()
        await asyncssh.scp(str(source), (conn, str(destination)))

    def run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Execute a command on the remote device.

        :param command: The command and arguments to execute.
        :param cwd: Optional working directory for command execution.
        :param hide: If ``True``, print only a compact one-liner instead of
            the full bordered box with streamed output. Use this for short
            utility commands (e.g. ``mkdir``, ``rm``) that don't produce
            meaningful output.
        :return: The completed process result.
        :raises SSHConnectionError: If no connection is established.
        :raises asyncssh.ProcessError: If the command fails.
        """
        return self._ensure_loop().run_until_complete(
            self._run(command, cwd, hide=hide)
        )

    def get(self, source: RemotePath, destination: LocalPath) -> None:
        """Download a file from the remote device.

        :param source: The remote file path to download.
        :param destination: The local destination path.
        :raises SSHConnectionError: If no connection is established.
        """
        self._ensure_loop().run_until_complete(self._get(source, destination))

    def put(self, source: LocalPath, destination: RemotePath) -> None:
        """Upload a file to the remote device.

        :param source: The local file path to upload.
        :param destination: The remote destination path.
        :raises SSHConnectionError: If no connection is established.
        """
        self._ensure_loop().run_until_complete(self._put(source, destination))