Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 90 additions & 40 deletions Lib/profiling/sampling/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __call__(self, parser, namespace, values, option_string=None):
_PROCESS_KILL_TIMEOUT_SEC = 2.0
_READY_MESSAGE = b"ready"
_RECV_BUFFER_SIZE = 1024
_BINARY_PROFILE_HEADER_SIZE = 64
_BINARY_PROFILE_MAGICS = (b"HCAT", b"TACH")

# Format configuration
FORMAT_EXTENSIONS = {
Expand Down Expand Up @@ -650,6 +652,88 @@ def _open_in_browser(path):
print(f"Warning: Could not open browser: {e}", file=sys.stderr)


def _validate_replay_input_file(filename):
"""Validate that the replay input looks like a sampling binary profile."""
try:
with open(filename, "rb") as file:
header = file.read(_BINARY_PROFILE_HEADER_SIZE)
except OSError as exc:
sys.exit(f"Error: Could not read input file {filename}: {exc}")

if (
len(header) < _BINARY_PROFILE_HEADER_SIZE
or header[:4] not in _BINARY_PROFILE_MAGICS
):
sys.exit(
"Error: Input file is not a binary sampling profile. "
"The replay command only accepts files created with --binary"
)


def _replay_with_reader(args, reader):
"""Replay samples from an open binary reader."""
info = reader.get_info()
interval = info['sample_interval_us']

print(f"Replaying {info['sample_count']} samples from {args.input_file}")
print(f" Sample interval: {interval} us")
print(
" Compression: "
f"{'zstd' if info.get('compression_type', 0) == 1 else 'none'}"
)

collector = _create_collector(
args.format, interval, skip_idle=False,
diff_baseline=args.diff_baseline
)

def progress_callback(current, total):
if total > 0:
pct = current / total
bar_width = 40
filled = int(bar_width * pct)
bar = '█' * filled + '░' * (bar_width - filled)
print(
f"\r [{bar}] {pct*100:5.1f}% ({current:,}/{total:,})",
end="",
flush=True,
)

count = reader.replay_samples(collector, progress_callback)
print()

if args.format == "pstats":
if args.outfile:
collector.export(args.outfile)
else:
sort_choice = (
args.sort if args.sort is not None else "nsamples"
)
limit = args.limit if args.limit is not None else 15
sort_mode = _sort_to_mode(sort_choice)
collector.print_stats(
sort_mode, limit, not args.no_summary,
PROFILING_MODE_WALL
)
else:
filename = (
args.outfile
or _generate_output_filename(args.format, os.getpid())
)
collector.export(filename)

# Auto-open browser for HTML output if --browser flag is set
if (
args.format in (
'flamegraph', 'diff_flamegraph', 'heatmap'
)
and getattr(args, 'browser', False)
):
_open_in_browser(filename)

print(f"Replayed {count} samples")


def _handle_output(collector, args, pid, mode):
"""Handle output for the collector based on format and arguments.

Expand Down Expand Up @@ -1201,47 +1285,13 @@ def _handle_replay(args):
if not os.path.exists(args.input_file):
sys.exit(f"Error: Input file not found: {args.input_file}")

with BinaryReader(args.input_file) as reader:
info = reader.get_info()
interval = info['sample_interval_us']
_validate_replay_input_file(args.input_file)

print(f"Replaying {info['sample_count']} samples from {args.input_file}")
print(f" Sample interval: {interval} us")
print(f" Compression: {'zstd' if info.get('compression_type', 0) == 1 else 'none'}")

collector = _create_collector(
args.format, interval, skip_idle=False,
diff_baseline=args.diff_baseline
)

def progress_callback(current, total):
if total > 0:
pct = current / total
bar_width = 40
filled = int(bar_width * pct)
bar = '█' * filled + '░' * (bar_width - filled)
print(f"\r [{bar}] {pct*100:5.1f}% ({current:,}/{total:,})", end="", flush=True)

count = reader.replay_samples(collector, progress_callback)
print()

if args.format == "pstats":
if args.outfile:
collector.export(args.outfile)
else:
sort_choice = args.sort if args.sort is not None else "nsamples"
limit = args.limit if args.limit is not None else 15
sort_mode = _sort_to_mode(sort_choice)
collector.print_stats(sort_mode, limit, not args.no_summary, PROFILING_MODE_WALL)
else:
filename = args.outfile or _generate_output_filename(args.format, os.getpid())
collector.export(filename)

# Auto-open browser for HTML output if --browser flag is set
if args.format in ('flamegraph', 'diff_flamegraph', 'heatmap') and getattr(args, 'browser', False):
_open_in_browser(filename)

print(f"Replayed {count} samples")
try:
with BinaryReader(args.input_file) as reader:
_replay_with_reader(args, reader)
except (OSError, ValueError) as exc:
sys.exit(f"Error: {exc}")


if __name__ == "__main__":
Expand Down
37 changes: 37 additions & 0 deletions Lib/test/test_profiling/test_sampling_profiler/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests for sampling profiler CLI argument parsing and functionality."""

import io
import os
import subprocess
import sys
import tempfile
import unittest
from unittest import mock

Expand Down Expand Up @@ -722,3 +724,38 @@ def test_cli_attach_nonexistent_pid(self):
main()

self.assertIn(fake_pid, str(cm.exception))

def test_cli_replay_rejects_non_binary_profile(self):
with tempfile.TemporaryDirectory() as tempdir:
profile = os.path.join(tempdir, "output.prof")
with open(profile, "wb") as file:
file.write(b"not a binary sampling profile")

with mock.patch("sys.argv", ["profiling.sampling.cli", "replay", profile]):
with self.assertRaises(SystemExit) as cm:
main()

error = str(cm.exception)
self.assertIn("not a binary sampling profile", error)
self.assertIn("--binary", error)

def test_cli_replay_reader_errors_exit_cleanly(self):
with tempfile.TemporaryDirectory() as tempdir:
profile = os.path.join(tempdir, "output.bin")
with open(profile, "wb") as file:
file.write(b"HCAT" + (b"\0" * 60))

with (
mock.patch("sys.argv", ["profiling.sampling.cli", "replay", profile]),
mock.patch(
"profiling.sampling.cli.BinaryReader",
side_effect=ValueError("Unsupported format version 2"),
),
):
with self.assertRaises(SystemExit) as cm:
main()

self.assertEqual(
str(cm.exception),
"Error: Unsupported format version 2",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The :mod:`profiling.sampling` ``replay`` command now rejects non-binary
profile files with a clear error explaining that replay only accepts files
created with ``--binary``.
Loading