Source code for embedl_hub._internal.core.device.ssh

# Copyright (C) 2026 Embedl AB

"""SSH-based command runner for remote device execution.

This module provides an SSH command runner implementation using asyncssh
for executing commands and transferring files on remote devices.
"""

import asyncio
import io
import os
import warnings
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from typing import cast

import asyncssh
from rich.console import Console
from rich.text import Text
from typing_extensions import Self

from embedl_hub._internal.core.device.abc import (
    CommandResult,
    CommandRunner,
    Connectable,
)
from embedl_hub._internal.core.types import LocalPath, RemotePath


def _ssh_agent_available() -> bool:
    """Return True if an SSH agent socket is present in the environment."""
    return bool(os.environ.get("SSH_AUTH_SOCK"))


class SSHConnectionError(Exception):
    """Signal that an SSH operation cannot proceed for lack of a connection.

    Raised when a runner method is called before a connection has been
    established, and when establishing the connection itself fails.
    """
@dataclass
class SSHConfig:
    """Connection settings for reaching a remote device over SSH.

    :param host: Hostname or IP address of the remote device.
    :param username: Login name used for authentication.
    :param port: TCP port of the SSH server (default: 22).
    :param password: Optional password for authentication. Password
        authentication is discouraged for security reasons and will raise
        a warning; prefer key-based authentication instead.
    :param private_key_path: Path to a specific private key file. Overrides
        all other authentication methods — the SSH agent and default key
        discovery are both skipped. Only needed in unusual setups where a
        non-standard key path must be forced.
    :param known_hosts: Path to a known_hosts file. ``None`` (the default)
        disables host key checking entirely.
    """

    host: str
    username: str
    port: int = 22
    password: str | None = None
    private_key_path: LocalPath | None = None
    known_hosts: LocalPath | None = None
_console = Console(stderr=True)
"""Shared Rich console for SSH output (writes to stderr so it doesn't mix
with real program stdout when piped)."""

_BORDER_STYLE = "dim cyan"
_INDENT = " "
_PREFIX = Text(f"{_INDENT}│ ", style="dim cyan")


def _box_width(console: Console) -> int:
    """Return the number of rule characters drawn after the indent."""
    # One column is reserved on top of the indent.
    return console.width - len(_INDENT) - 1


def _box_top(console: Console, title: str | None = None) -> None:
    """Print the top rule of a box, optionally with a centred title.

    :param console: Console to print to.
    :param title: Optional label embedded in the middle of the rule.
    """
    width = _box_width(console)
    if title:
        label = f" {title} "
        # At least one rule character on each side, even for long titles.
        left = max((width - len(label)) // 2, 1)
        right = max(width - left - len(label), 1)
        line = f"{_INDENT}{'─' * left}{label}{'─' * right}"
    else:
        line = f"{_INDENT}{'─' * width}"
    console.print(Text(line, style=_BORDER_STYLE), highlight=False)


def _box_bottom(console: Console) -> None:
    """Print the plain bottom rule of a box.

    :param console: Console to print to.
    """
    width = _box_width(console)
    console.print(
        Text(f"{_INDENT}{'─' * width}", style=_BORDER_STYLE),
        highlight=False,
    )


class _TeeWriter(io.RawIOBase):
    """Captures bytes *and* prints each line to the console with a prefix.

    asyncssh closes file objects passed as stdout/stderr, so this wrapper
    prevents the real stream from being closed while also accumulating all
    written bytes so they can be retrieved afterwards.

    Pending partial lines are buffered as raw bytes and split on the
    newline byte (0x0A never appears inside a multi-byte UTF-8 sequence).
    Only complete lines are decoded, so a character split across two
    ``write`` calls is printed intact instead of as replacement characters.
    """

    def __init__(self, console: Console, style: str = "") -> None:
        """Create a tee writer.

        :param console: Console that each completed line is printed to.
        :param style: Rich style applied to the printed line text.
        """
        super().__init__()
        self._console = console
        self._style = style
        self._parts: list[bytes] = []
        self._line_buf = b""

    def writable(self) -> bool:
        """Report that this stream accepts writes (``RawIOBase`` defaults to False)."""
        return True

    def write(self, b: bytes | bytearray) -> int:  # type: ignore[override]
        """Record *b* and print every line it completes.

        :param b: Raw bytes written by the producer.
        :return: Number of bytes consumed (always ``len(b)``).
        """
        data = bytes(b)
        self._parts.append(data)
        lines = (self._line_buf + data).split(b"\n")
        # Last element is the incomplete line (may be empty).
        self._line_buf = lines.pop()
        for raw_line in lines:
            self._print_line(raw_line)
        return len(data)

    def _print_line(self, raw_line: bytes) -> None:
        """Decode one complete line and print it with the box prefix."""
        text = raw_line.decode("utf-8", errors="replace")
        self._console.print(
            _PREFIX + Text(text, style=self._style),
            highlight=False,
        )

    def flush_remaining(self) -> None:
        """Print any remaining partial line in the buffer."""
        if self._line_buf:
            self._print_line(self._line_buf)
            self._line_buf = b""

    @property
    def captured(self) -> str:
        """Return all captured output decoded as UTF-8."""
        return b"".join(self._parts).decode("utf-8", errors="replace")

    def close(self) -> None:
        # Intentionally do NOT close the underlying stream.
        pass
class SSHCommandRunner(CommandRunner, Connectable):
    """SSH-based command runner for remote device execution.

    Implements the :class:`CommandRunner` and :class:`Connectable` protocols
    using asyncssh for SSH connections. Call :meth:`connect` to obtain a
    context manager that establishes the connection.

    Each connection gets its own asyncio event loop, created in
    :meth:`connect` and closed when the context exits. The synchronous
    :meth:`run`, :meth:`get` and :meth:`put` methods drive their coroutines
    to completion on that loop.

    Example::

        runner = SSHCommandRunner(config)
        with runner.connect():
            runner.run(["ls", "-la"])
    """

    device_name: str | None
    """Human-readable device name, set by the owning
    :class:`~embedl_hub._internal.core.device.abc.Device`."""

    def __init__(self, config: SSHConfig) -> None:
        """Initialize the SSH command runner.

        :param config: SSH connection configuration.
        """
        if config.password is not None:
            warnings.warn(
                "Using password authentication is discouraged for security reasons. "
                "Consider using key-based authentication instead.",
                UserWarning,
                stacklevel=2,
            )
        self._config = config
        self._connection: asyncssh.SSHClientConnection | None = None
        self._loop: asyncio.AbstractEventLoop | None = None
        self.device_name: str | None = None

    @property
    def is_active(self) -> bool:
        """Whether an SSH connection is currently established."""
        return self._connection is not None

    async def _connect(self) -> asyncssh.SSHClientConnection:
        """Establish an SSH connection to the remote device.

        Key selection:

        * ``private_key_path`` set: only that key is offered and the SSH
          agent is disabled, as documented on :class:`SSHConfig`.
        * otherwise, if ``SSH_AUTH_SOCK`` is set: the agent is used and no
          on-disk keys are loaded (``client_keys=[]``).
        * otherwise asyncssh's own defaults apply.

        A configured password is passed in addition to any of the above.

        :return: An active SSH client connection.
        """
        config = self._config
        connect_kwargs: dict[str, object] = {
            "host": config.host,
            "port": config.port,
            "username": config.username,
            # None disables host key checking in asyncssh.
            "known_hosts": (
                None if config.known_hosts is None else str(config.known_hosts)
            ),
        }
        if config.password is not None:
            connect_kwargs["password"] = config.password
        if config.private_key_path is not None:
            connect_kwargs["client_keys"] = [str(config.private_key_path)]
            # Without an explicit None, asyncssh falls back to SSH_AUTH_SOCK
            # and tries agent keys before the forced key.
            connect_kwargs["agent_path"] = None
        elif _ssh_agent_available():
            connect_kwargs["agent_path"] = os.environ.get("SSH_AUTH_SOCK")
            connect_kwargs["client_keys"] = []
        return await asyncssh.connect(**connect_kwargs)

    def _ensure_connected(self) -> asyncssh.SSHClientConnection:
        """Ensure a connection is established and return it.

        :return: The active SSH connection.
        :raises SSHConnectionError: If no connection is established.
        """
        if self._connection is None:
            raise SSHConnectionError(
                "No SSH connection established. "
                "Call runner.connect() first: `with runner.connect(): ...`"
            )
        return self._connection

    def _ensure_loop(self) -> asyncio.AbstractEventLoop:
        """Return the active event loop.

        :return: The event loop created during :meth:`connect`.
        :raises SSHConnectionError: If no connection (and thus no loop) exists.
        """
        if self._loop is None:
            raise SSHConnectionError(
                "No SSH connection established. "
                "Call runner.connect() first: `with runner.connect(): ...`"
            )
        return self._loop

    @contextmanager
    def connect(self) -> Generator[Self, None, None]:
        """Establish an SSH connection for the duration of the context.

        The event loop created here is always closed again: on a failed
        connection attempt (whatever the exception) and when the context
        exits, even if closing the SSH connection raises.

        :yields: This runner instance with an active connection.
        :raises RuntimeError: If a connection is already active.
        :raises SSHConnectionError: If the configured private key file does
            not exist, or if the connection or authentication fails. Other
            asyncssh errors propagate unchanged.
        """
        if self._connection is not None:
            raise RuntimeError(
                "SSHCommandRunner already has an active connection. "
                "Cannot open a new connection while one is active."
            )

        # Validate private key path before allocating any resources.
        # isfile (not exists) so a directory is reported here rather than
        # surfacing later as a misleading connection OSError.
        if self._config.private_key_path is not None:
            if not os.path.isfile(str(self._config.private_key_path)):
                raise SSHConnectionError(
                    f"Private key file not found: '{self._config.private_key_path}'. "
                    "Check the path passed to SSHConfig(private_key_path=...)."
                )

        loop = asyncio.new_event_loop()
        self._loop = loop

        def _cleanup_loop() -> None:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                self._loop = None

        def _target_str() -> str:
            if self.device_name:
                return (
                    f"device '{self.device_name}' "
                    f"({self._config.username}@{self._config.host}:{self._config.port})"
                )
            return f"{self._config.username}@{self._config.host}:{self._config.port}"

        try:
            self._connection = loop.run_until_complete(self._connect())
        except OSError as exc:
            _cleanup_loop()
            raise SSHConnectionError(
                f"Could not connect to {_target_str()}. "
                f"Check that the device is reachable and the SSH server is running."
            ) from exc
        except asyncssh.KeyImportError as exc:
            _cleanup_loop()
            raise SSHConnectionError(
                "Could not load SSH private keys. "
                "If you use a hardware security key (YubiKey or similar), "
                "make sure your SSH agent is running and SSH_AUTH_SOCK is set "
                "in the environment where you run this script."
            ) from exc
        except asyncssh.PermissionDenied as exc:
            _cleanup_loop()
            raise SSHConnectionError(
                f"SSH authentication failed for {_target_str()}. "
                "Check that the username is correct and the key is authorised "
                "on the device (see ~/.ssh/authorized_keys)."
            ) from exc
        except Exception:
            # Any other failure must not leak the event loop.
            _cleanup_loop()
            raise

        try:
            yield self
        finally:
            try:
                connection = self._connection
                if connection is not None:
                    connection.close()
                    loop.run_until_complete(connection.wait_closed())
            finally:
                # Runs even if wait_closed() raised, so the loop is closed.
                self._connection = None
                _cleanup_loop()

    async def _run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Run *command* remotely, echoing it (and its output unless *hide*).

        :param command: The command and arguments; joined with spaces.
        :param cwd: Optional remote directory to ``cd`` into first.
        :param hide: Print a compact one-liner instead of a streamed box.
        :return: The completed process result.
        :raises asyncssh.ProcessError: If the command fails; its ``stdout``
            and ``stderr`` carry the streamed output.
        """
        conn = self._ensure_connected()
        # NOTE: neither the arguments nor ``cwd`` are shell-quoted; they are
        # sent to the remote side as a single command string.
        cmd_str = " ".join(map(str, command))
        if cwd is not None:
            cmd_str = f"cd {cwd} && {cmd_str}"

        console = _console
        name = self.device_name
        addr = f"{self._config.username}@{self._config.host}"

        console.print()
        header = (
            f"Running command on '{name}' ({addr}):"
            if name is not None
            else f"Running command on {addr}:"
        )
        console.print(header, highlight=False)

        if hide:
            # Compact one-liner: no box, no streaming output.
            console.print(
                Text(f"{_INDENT}└ ", style=_BORDER_STYLE)
                + Text("$ ", style="bold green")
                + Text(cmd_str, style="bold"),
                highlight=False,
            )
            result = await conn.run(cmd_str, check=True)
            return cast(CommandResult, result)

        # Box with streamed output.
        _box_top(console, title=name or addr)
        console.print(
            _PREFIX + Text("$ ", style="bold green") + Text(cmd_str, style="bold"),
            highlight=False,
        )
        stdout_tee = _TeeWriter(console)
        stderr_tee = _TeeWriter(console, style="red")
        try:
            result = await conn.run(
                cmd_str,
                check=True,
                stdout=stdout_tee,
                stderr=stderr_tee,
            )
        except asyncssh.ProcessError as exc:
            # asyncssh.ProcessError.stdout/stderr may be empty when output
            # was streamed to our TeeWriter. Patch the captured text back
            # onto the exception so callers (e.g. run_remote_command) can
            # surface it in their error messages.
            exc.stdout = stdout_tee.captured
            exc.stderr = stderr_tee.captured
            raise
        finally:
            # Close the box on success and on every failure path.
            stdout_tee.flush_remaining()
            stderr_tee.flush_remaining()
            _box_bottom(console)

        result.stdout = stdout_tee.captured
        result.stderr = stderr_tee.captured
        return cast(CommandResult, result)

    async def _get(self, source: RemotePath, destination: LocalPath) -> None:
        """Copy *source* from the device to *destination* via SCP."""
        conn = self._ensure_connected()
        await asyncssh.scp((conn, str(source)), str(destination))

    async def _put(self, source: LocalPath, destination: RemotePath) -> None:
        """Copy local *source* to *destination* on the device via SCP."""
        conn = self._ensure_connected()
        await asyncssh.scp(str(source), (conn, str(destination)))

    def run(
        self,
        command: Sequence[str],
        cwd: RemotePath | None = None,
        *,
        hide: bool = False,
    ) -> CommandResult:
        """Execute a command on the remote device.

        :param command: The command and arguments to execute.
        :param cwd: Optional working directory for command execution.
        :param hide: If ``True``, print only a compact one-liner instead of
            the full bordered box with streamed output. Use this for short
            utility commands (e.g. ``mkdir``, ``rm``) that don't produce
            meaningful output.
        :return: The completed process result.
        :raises SSHConnectionError: If no connection is established.
        :raises asyncssh.ProcessError: If the command fails.
        """
        return self._ensure_loop().run_until_complete(
            self._run(command, cwd, hide=hide)
        )

    def get(self, source: RemotePath, destination: LocalPath) -> None:
        """Download a file from the remote device.

        :param source: The remote file path to download.
        :param destination: The local destination path.
        :raises SSHConnectionError: If no connection is established.
        """
        self._ensure_loop().run_until_complete(self._get(source, destination))

    def put(self, source: LocalPath, destination: RemotePath) -> None:
        """Upload a file to the remote device.

        :param source: The local file path to upload.
        :param destination: The remote destination path.
        :raises SSHConnectionError: If no connection is established.
        """
        self._ensure_loop().run_until_complete(self._put(source, destination))