Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions src/pyssm_client/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def exec_command(
endpoint_url: str | None,
timeout: int,
) -> None:
"""Execute a single command and return stdout/stderr/exit code."""
"""Execute a single command with real-time streaming output."""
try:
result = run_command_sync(
target=target,
Expand All @@ -401,22 +401,10 @@ def exec_command(
region=region,
endpoint_url=endpoint_url,
timeout=timeout,
stream_output=True,
)
# Print streams; exit with command exit code
if result.stdout:
try:
sys.stdout.buffer.write(result.stdout)
sys.stdout.buffer.flush()
except Exception:
click.echo(result.stdout.decode("utf-8", errors="replace"), nl=False)
if result.stderr:
try:
sys.stderr.buffer.write(result.stderr)
sys.stderr.buffer.flush()
except Exception:
click.echo(
result.stderr.decode("utf-8", errors="replace"), nl=False, err=True
)
# Output is automatically streamed during execution
# Just exit with the command's exit code
sys.exit(result.exit_code)
except Exception as e:
click.echo(f"Execution failed: {e}", err=True)
Expand Down
13 changes: 2 additions & 11 deletions src/pyssm_client/file_transfer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,10 @@ async def _create_ssm_session(

async def _setup_data_channel(self, session_data: dict) -> tuple[Any, Any]:
"""Set up data channel for file transfer."""
from ..communicator.data_channel import SessionDataChannel
from ..cli.types import ConnectArguments
from ..session.registry import get_session_registry
from ..communicator.data_channel import SessionDataChannel
from ..session.plugins import StandardStreamPlugin
from ..session.registry import get_session_registry
from ..session.session_handler import SessionHandler

# Create session object first
Expand Down Expand Up @@ -507,15 +507,6 @@ async def _upload_base64(
# Small delay to avoid overwhelming remote
await asyncio.sleep(0.005)

if options.progress_callback:
try:
import sys

sys.stdout.write("\n")
sys.stdout.flush()
except Exception:
pass

self.logger.info("All chunks sent; waiting for remote flush...")
remote_size = await self._wait_for_remote_size(
temp_remote,
Expand Down
67 changes: 65 additions & 2 deletions src/pyssm_client/utils/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ async def run_command(
region: Optional[str] = None,
endpoint_url: Optional[str] = None,
timeout: int = 600,
stream_output: bool = False,
) -> CommandResult:
"""Execute a single command on target and return stdout/stderr/exit_code.

Expand All @@ -86,6 +87,7 @@ async def run_command(
region: AWS region
endpoint_url: Custom AWS endpoint URL
timeout: Command timeout in seconds
stream_output: Whether to stream filtered output to stdout/stderr in real-time

Returns:
CommandResult with separated stdout, stderr, and exit code
Expand Down Expand Up @@ -146,9 +148,40 @@ async def run_command(
session_done = asyncio.Event()
exit_code = 0

# Line buffers for streaming
stdout_line_buf = bytearray()
stderr_line_buf = bytearray()

def handle_stdout(data: bytes) -> None:
"""Handle stdout from shell."""
nonlocal stdout_buf, exit_code
nonlocal stdout_buf, exit_code, stdout_line_buf

# Add to line buffer for streaming
stdout_line_buf.extend(data)

# Process complete lines for streaming
while b"\n" in stdout_line_buf:
line_end = stdout_line_buf.index(b"\n")
line = stdout_line_buf[: line_end + 1]
stdout_line_buf = stdout_line_buf[line_end + 1 :]

# Apply existing filter to the line and stream if requested
if stream_output:
filtered = _filter_shell_output(line, command)
if filtered and filtered.strip():
try:
import sys

sys.stdout.buffer.write(filtered)
sys.stdout.buffer.flush()
except Exception:
try:
import sys

sys.stdout.write(filtered.decode("utf-8", errors="replace"))
sys.stdout.flush()
except Exception:
pass

# Check for exit status marker in stdout
try:
Expand All @@ -172,7 +205,35 @@ def handle_stdout(data: bytes) -> None:

def handle_stderr(data: bytes) -> None:
"""Handle stderr from shell."""
nonlocal stderr_buf
nonlocal stderr_buf, stderr_line_buf

# Add to line buffer for streaming
stderr_line_buf.extend(data)

# Process complete lines for streaming
while b"\n" in stderr_line_buf:
line_end = stderr_line_buf.index(b"\n")
line = stderr_line_buf[: line_end + 1]
stderr_line_buf = stderr_line_buf[line_end + 1 :]

# Apply existing filter to stderr line and stream if requested
if stream_output:
filtered = _filter_shell_output(line, command)
if filtered and filtered.strip():
try:
import sys

sys.stderr.buffer.write(filtered)
sys.stderr.buffer.flush()
except Exception:
try:
import sys

sys.stderr.write(filtered.decode("utf-8", errors="replace"))
sys.stderr.flush()
except Exception:
pass

stderr_buf.extend(data)

def handle_closed() -> None:
Expand Down Expand Up @@ -250,6 +311,7 @@ def run_command_sync(
region: Optional[str] = None,
endpoint_url: Optional[str] = None,
timeout: int = 600,
stream_output: bool = False,
) -> CommandResult:
"""Synchronous wrapper for run_command()."""
return asyncio.run(
Expand All @@ -260,5 +322,6 @@ def run_command_sync(
region=region,
endpoint_url=endpoint_url,
timeout=timeout,
stream_output=stream_output,
)
)