
Commit 56a442d

GH-141565: Add async code awareness to Tachyon (#141533)
Co-authored-by: Pablo Galindo Salgado <[email protected]>
1 parent 35142b1 commit 56a442d

File tree

14 files changed: +1360 −88 lines


Lib/profiling/sampling/cli.py

Lines changed: 40 additions & 1 deletion
@@ -195,6 +195,11 @@ def _add_sampling_options(parser):
         dest="gc",
         help='Don\'t include artificial "<GC>" frames to denote active garbage collection',
     )
+    sampling_group.add_argument(
+        "--async-aware",
+        action="store_true",
+        help="Enable async-aware profiling (uses task-based stack reconstruction)",
+    )
 
 
 def _add_mode_options(parser):
@@ -205,7 +210,14 @@ def _add_mode_options(parser):
         choices=["wall", "cpu", "gil"],
         default="wall",
         help="Sampling mode: wall (all samples), cpu (only samples when thread is on CPU), "
-        "gil (only samples when thread holds the GIL)",
+        "gil (only samples when thread holds the GIL). Incompatible with --async-aware",
+    )
+    mode_group.add_argument(
+        "--async-mode",
+        choices=["running", "all"],
+        default="running",
+        help='Async profiling mode: "running" (only running task) '
+        'or "all" (all tasks including waiting). Requires --async-aware',
     )
 
 
@@ -382,6 +394,27 @@ def _validate_args(args, parser):
             "Live mode requires the curses module, which is not available."
         )
 
+    # Async-aware mode is incompatible with --native, --no-gc, --mode, and --all-threads
+    if args.async_aware:
+        issues = []
+        if args.native:
+            issues.append("--native")
+        if not args.gc:
+            issues.append("--no-gc")
+        if hasattr(args, 'mode') and args.mode != "wall":
+            issues.append(f"--mode={args.mode}")
+        if hasattr(args, 'all_threads') and args.all_threads:
+            issues.append("--all-threads")
+        if issues:
+            parser.error(
+                f"Options {', '.join(issues)} are incompatible with --async-aware. "
+                "Async-aware profiling uses task-based stack reconstruction."
+            )
+
+    # --async-mode requires --async-aware
+    if hasattr(args, 'async_mode') and args.async_mode != "running" and not args.async_aware:
+        parser.error("--async-mode requires --async-aware to be enabled.")
+
     # Live mode is incompatible with format options
     if hasattr(args, 'live') and args.live:
         if args.format != "pstats":
@@ -570,6 +603,7 @@ def _handle_attach(args):
         all_threads=args.all_threads,
         realtime_stats=args.realtime_stats,
         mode=mode,
+        async_aware=args.async_mode if args.async_aware else None,
         native=args.native,
         gc=args.gc,
     )
@@ -618,6 +652,7 @@ def _handle_run(args):
         all_threads=args.all_threads,
         realtime_stats=args.realtime_stats,
         mode=mode,
+        async_aware=args.async_mode if args.async_aware else None,
         native=args.native,
         gc=args.gc,
     )
@@ -650,6 +685,7 @@ def _handle_live_attach(args, pid):
         limit=20,  # Default limit
         pid=pid,
         mode=mode,
+        async_aware=args.async_mode if args.async_aware else None,
     )
 
     # Sample in live mode
@@ -660,6 +696,7 @@ def _handle_live_attach(args, pid):
         all_threads=args.all_threads,
         realtime_stats=args.realtime_stats,
         mode=mode,
+        async_aware=args.async_mode if args.async_aware else None,
         native=args.native,
         gc=args.gc,
     )
@@ -689,6 +726,7 @@ def _handle_live_run(args):
         limit=20,  # Default limit
         pid=process.pid,
         mode=mode,
+        async_aware=args.async_mode if args.async_aware else None,
     )
 
     # Profile the subprocess in live mode
@@ -700,6 +738,7 @@ def _handle_live_run(args):
         all_threads=args.all_threads,
         realtime_stats=args.realtime_stats,
         mode=mode,
+        async_aware=args.async_mode if args.async_aware else None,
         native=args.native,
         gc=args.gc,
     )
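
Taken together, the cli.py changes funnel both new flags into a single async_aware argument: it carries the chosen --async-mode string when --async-aware is set and None otherwise, and _validate_args rejects any combination that task-based reconstruction cannot support. The sketch below mirrors that wiring with a stand-alone argparse parser; the parser construction, error messages, and sample invocation are illustrative assumptions, not Tachyon's actual CLI code.

import argparse

# Hypothetical stand-alone parser; option names, choices, and defaults mirror the diff.
parser = argparse.ArgumentParser()
parser.add_argument("--async-aware", action="store_true")
parser.add_argument("--async-mode", choices=["running", "all"], default="running")
parser.add_argument("--mode", choices=["wall", "cpu", "gil"], default="wall")
parser.add_argument("--native", action="store_true")

args = parser.parse_args(["--async-aware", "--async-mode", "all"])

# Same shape as the checks in _validate_args: async-aware profiling is
# wall-clock only and cannot be combined with native frame collection.
if args.async_aware and (args.native or args.mode != "wall"):
    parser.error("--native and --mode=cpu/gil are incompatible with --async-aware")
if args.async_mode != "running" and not args.async_aware:
    parser.error("--async-mode requires --async-aware to be enabled")

# The value handed down to the sampler, as in the _handle_* functions:
async_aware = args.async_mode if args.async_aware else None
print(async_aware)  # -> "all"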

Lib/profiling/sampling/collector.py

Lines changed: 96 additions & 1 deletion
@@ -2,10 +2,16 @@
 from .constants import (
     THREAD_STATUS_HAS_GIL,
     THREAD_STATUS_ON_CPU,
-    THREAD_STATUS_UNKNOWN,
     THREAD_STATUS_GIL_REQUESTED,
+    THREAD_STATUS_UNKNOWN,
 )
 
+try:
+    from _remote_debugging import FrameInfo
+except ImportError:
+    # Fallback definition if _remote_debugging is not available
+    FrameInfo = None
+
 class Collector(ABC):
     @abstractmethod
     def collect(self, stack_frames):
@@ -33,6 +39,95 @@ def _iter_all_frames(self, stack_frames, skip_idle=False):
             if frames:
                 yield frames, thread_info.thread_id
 
+    def _iter_async_frames(self, awaited_info_list):
+        # Phase 1: Index tasks and build parent relationships with pre-computed selection
+        task_map, child_to_parent, all_task_ids, all_parent_ids = self._build_task_graph(awaited_info_list)
+
+        # Phase 2: Find leaf tasks (tasks not awaited by anyone)
+        leaf_task_ids = self._find_leaf_tasks(all_task_ids, all_parent_ids)
+
+        # Phase 3: Build linear stacks from each leaf to root (optimized - no sorting!)
+        yield from self._build_linear_stacks(leaf_task_ids, task_map, child_to_parent)
+
+    def _build_task_graph(self, awaited_info_list):
+        task_map = {}
+        child_to_parent = {}  # Maps child_id -> (selected_parent_id, parent_count)
+        all_task_ids = set()
+        all_parent_ids = set()  # Track ALL parent IDs for leaf detection
+
+        for awaited_info in awaited_info_list:
+            thread_id = awaited_info.thread_id
+            for task_info in awaited_info.awaited_by:
+                task_id = task_info.task_id
+                task_map[task_id] = (task_info, thread_id)
+                all_task_ids.add(task_id)
+
+                # Pre-compute selected parent and count for optimization
+                if task_info.awaited_by:
+                    parent_ids = [p.task_name for p in task_info.awaited_by]
+                    parent_count = len(parent_ids)
+                    # Track ALL parents for leaf detection
+                    all_parent_ids.update(parent_ids)
+                    # Use min() for O(n) instead of sorted()[0] which is O(n log n)
+                    selected_parent = min(parent_ids) if parent_count > 1 else parent_ids[0]
+                    child_to_parent[task_id] = (selected_parent, parent_count)
+
+        return task_map, child_to_parent, all_task_ids, all_parent_ids
+
+    def _find_leaf_tasks(self, all_task_ids, all_parent_ids):
+        # Leaves are tasks that are not parents of any other task
+        return all_task_ids - all_parent_ids
+
+    def _build_linear_stacks(self, leaf_task_ids, task_map, child_to_parent):
+        for leaf_id in leaf_task_ids:
+            frames = []
+            visited = set()
+            current_id = leaf_id
+            thread_id = None
+
+            # Follow the single parent chain from leaf to root
+            while current_id is not None:
+                # Cycle detection
+                if current_id in visited:
+                    break
+                visited.add(current_id)
+
+                # Check if task exists in task_map
+                if current_id not in task_map:
+                    break
+
+                task_info, tid = task_map[current_id]
+
+                # Set thread_id from first task
+                if thread_id is None:
+                    thread_id = tid
+
+                # Add all frames from all coroutines in this task
+                if task_info.coroutine_stack:
+                    for coro_info in task_info.coroutine_stack:
+                        for frame in coro_info.call_stack:
+                            frames.append(frame)
+
+                # Get pre-computed parent info (no sorting needed!)
+                parent_info = child_to_parent.get(current_id)
+
+                # Add task boundary marker with parent count annotation if multiple parents
+                task_name = task_info.task_name or "Task-" + str(task_info.task_id)
+                if parent_info:
+                    selected_parent, parent_count = parent_info
+                    if parent_count > 1:
+                        task_name = f"{task_name} ({parent_count} parents)"
+                    frames.append(FrameInfo(("<task>", 0, task_name)))
+                    current_id = selected_parent
+                else:
+                    # Root task - no parent
+                    frames.append(FrameInfo(("<task>", 0, task_name)))
+                    current_id = None
+
+            # Yield the complete stack if we collected any frames
+            if frames and thread_id is not None:
+                yield frames, thread_id, leaf_id
+
     def _is_gc_frame(self, frame):
         if isinstance(frame, tuple):
             funcname = frame[2] if len(frame) >= 3 else ""
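
The new _iter_async_frames helper in collector.py is a plain graph walk over the awaited-by data: index every task, pick one parent per child, treat tasks that are nobody's parent as leaves, and then walk each leaf up its parent chain while concatenating coroutine frames and "<task>" boundary markers. A rough sketch of the same three phases on toy data, with plain dicts and strings standing in for the real _remote_debugging task, coroutine, and frame objects (the stand-in types and task names are assumptions for illustration):

# task_id -> (frames recorded for this task, ids of the tasks awaiting it)
tasks = {
    "Task-1": (["main"], []),                      # root: awaited by nobody
    "Task-2": (["worker", "fetch"], ["Task-1"]),   # leaf: not a parent of any task
}

# Phase 1: index tasks and pre-compute each child's (selected_parent, parent_count)
child_to_parent = {
    tid: (min(parents), len(parents))
    for tid, (_, parents) in tasks.items()
    if parents
}
all_parent_ids = {p for _, parents in tasks.values() for p in parents}

# Phase 2: leaves are tasks that never appear as someone's parent
leaves = set(tasks) - all_parent_ids

# Phase 3: walk each leaf up its selected-parent chain, concatenating frames
for leaf in leaves:
    stack, current = [], leaf
    while current is not None:
        frames, _ = tasks[current]
        stack.extend(frames)
        stack.append(f"<task: {current}>")  # stands in for the FrameInfo boundary marker
        current = child_to_parent.get(current, (None, 0))[0]
    print(stack)
    # ['worker', 'fetch', '<task: Task-2>', 'main', '<task: Task-1>']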

Lib/profiling/sampling/live_collector/collector.py

Lines changed: 46 additions & 70 deletions
@@ -103,6 +103,7 @@ def __init__(
         pid=None,
         display=None,
         mode=None,
+        async_aware=None,
     ):
         """
         Initialize the live stats collector.
@@ -115,6 +116,7 @@ def __init__(
             pid: Process ID being profiled
             display: DisplayInterface implementation (None means curses will be used)
             mode: Profiling mode ('cpu', 'gil', etc.) - affects what stats are shown
+            async_aware: Async tracing mode - None (sync only), "all" or "running"
         """
         self.result = collections.defaultdict(
             lambda: dict(total_rec_calls=0, direct_calls=0, cumulative_calls=0)
@@ -133,6 +135,9 @@ def __init__(
         self.running = True
         self.pid = pid
         self.mode = mode  # Profiling mode
+        self.async_aware = async_aware  # Async tracing mode
+        # Pre-select frame iterator method to avoid per-call dispatch overhead
+        self._get_frame_iterator = self._get_async_frame_iterator if async_aware else self._get_sync_frame_iterator
         self._saved_stdout = None
         self._saved_stderr = None
         self._devnull = None
@@ -294,6 +299,15 @@ def process_frames(self, frames, thread_id=None):
             if thread_data:
                 thread_data.result[top_location]["direct_calls"] += 1
 
+    def _get_sync_frame_iterator(self, stack_frames):
+        """Iterator for sync frames."""
+        return self._iter_all_frames(stack_frames, skip_idle=self.skip_idle)
+
+    def _get_async_frame_iterator(self, stack_frames):
+        """Iterator for async frames, yielding (frames, thread_id) tuples."""
+        for frames, thread_id, task_id in self._iter_async_frames(stack_frames):
+            yield frames, thread_id
+
     def collect_failed_sample(self):
         self.failed_samples += 1
         self.total_samples += 1
@@ -304,78 +318,40 @@ def collect(self, stack_frames):
             self.start_time = time.perf_counter()
             self._last_display_update = self.start_time
 
-        # Thread status counts for this sample
-        temp_status_counts = {
-            "has_gil": 0,
-            "on_cpu": 0,
-            "gil_requested": 0,
-            "unknown": 0,
-            "total": 0,
-        }
         has_gc_frame = False
 
-        # Always collect data, even when paused
-        # Track thread status flags and GC frames
-        for interpreter_info in stack_frames:
-            threads = getattr(interpreter_info, "threads", [])
-            for thread_info in threads:
-                temp_status_counts["total"] += 1
-
-                # Track thread status using bit flags
-                status_flags = getattr(thread_info, "status", 0)
-                thread_id = getattr(thread_info, "thread_id", None)
-
-                # Update aggregated counts
-                if status_flags & THREAD_STATUS_HAS_GIL:
-                    temp_status_counts["has_gil"] += 1
-                if status_flags & THREAD_STATUS_ON_CPU:
-                    temp_status_counts["on_cpu"] += 1
-                if status_flags & THREAD_STATUS_GIL_REQUESTED:
-                    temp_status_counts["gil_requested"] += 1
-                if status_flags & THREAD_STATUS_UNKNOWN:
-                    temp_status_counts["unknown"] += 1
-
-                # Update per-thread status counts
-                if thread_id is not None:
-                    thread_data = self._get_or_create_thread_data(thread_id)
-                    thread_data.increment_status_flag(status_flags)
-
-                # Process frames (respecting skip_idle)
-                if self.skip_idle:
-                    has_gil = bool(status_flags & THREAD_STATUS_HAS_GIL)
-                    on_cpu = bool(status_flags & THREAD_STATUS_ON_CPU)
-                    if not (has_gil or on_cpu):
-                        continue
-
-                frames = getattr(thread_info, "frame_info", None)
-                if frames:
-                    self.process_frames(frames, thread_id=thread_id)
-
-                    # Track thread IDs only for threads that actually have samples
-                    if (
-                        thread_id is not None
-                        and thread_id not in self.thread_ids
-                    ):
-                        self.thread_ids.append(thread_id)
-
-                    # Increment per-thread sample count and check for GC frames
-                    thread_has_gc_frame = False
-                    for frame in frames:
-                        funcname = getattr(frame, "funcname", "")
-                        if "<GC>" in funcname or "gc_collect" in funcname:
-                            has_gc_frame = True
-                            thread_has_gc_frame = True
-                            break
-
-                    if thread_id is not None:
-                        thread_data = self._get_or_create_thread_data(thread_id)
-                        thread_data.sample_count += 1
-                        if thread_has_gc_frame:
-                            thread_data.gc_frame_samples += 1
-
-        # Update cumulative thread status counts
-        for key, count in temp_status_counts.items():
-            self.thread_status_counts[key] += count
+        # Collect thread status stats (only available in sync mode)
+        if not self.async_aware:
+            status_counts, sample_has_gc, per_thread_stats = self._collect_thread_status_stats(stack_frames)
+            for key, count in status_counts.items():
+                self.thread_status_counts[key] += count
+            if sample_has_gc:
+                has_gc_frame = True
+
+            for thread_id, stats in per_thread_stats.items():
+                thread_data = self._get_or_create_thread_data(thread_id)
+                thread_data.has_gil += stats.get("has_gil", 0)
+                thread_data.on_cpu += stats.get("on_cpu", 0)
+                thread_data.gil_requested += stats.get("gil_requested", 0)
+                thread_data.unknown += stats.get("unknown", 0)
+                thread_data.total += stats.get("total", 0)
+                if stats.get("gc_samples", 0):
+                    thread_data.gc_frame_samples += stats["gc_samples"]
+
+        # Process frames using pre-selected iterator
+        for frames, thread_id in self._get_frame_iterator(stack_frames):
+            if not frames:
+                continue
+
+            self.process_frames(frames, thread_id=thread_id)
+
+            # Track thread IDs
+            if thread_id is not None and thread_id not in self.thread_ids:
+                self.thread_ids.append(thread_id)
+
+            if thread_id is not None:
+                thread_data = self._get_or_create_thread_data(thread_id)
+                thread_data.sample_count += 1
 
         if has_gc_frame:
             self.gc_frame_samples += 1
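
In the live collector, the branch between sync and async handling happens once, in __init__, by binding either _get_sync_frame_iterator or _get_async_frame_iterator to self._get_frame_iterator; collect() then calls the pre-selected method with no per-sample dispatch. A stripped-down illustration of that pattern follows; the class name and the tuple-based sample format are invented for the example, whereas the real collector consumes unwinder objects:

class FrameSource:
    def __init__(self, async_aware=None):
        self.async_aware = async_aware
        # Decide once which iterator collect() will use
        self._get_frame_iterator = (
            self._iter_async if async_aware else self._iter_sync
        )

    def _iter_sync(self, samples):
        # samples: iterable of (thread_id, frames)
        for thread_id, frames in samples:
            yield frames, thread_id

    def _iter_async(self, samples):
        # samples: iterable of (thread_id, frames, task_id); the task_id is
        # dropped, mirroring _get_async_frame_iterator in the diff
        for thread_id, frames, task_id in samples:
            yield frames, thread_id

    def collect(self, samples):
        return [frames for frames, _ in self._get_frame_iterator(samples)]

print(FrameSource().collect([(1, ["a", "b"])]))            # sync path
print(FrameSource("all").collect([(1, ["c"], "Task-1")]))  # async path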

Lib/profiling/sampling/pstats_collector.py

Lines changed: 8 additions & 2 deletions
@@ -42,8 +42,14 @@ def _process_frames(self, frames):
             self.callers[callee][caller] += 1
 
     def collect(self, stack_frames):
-        for frames, thread_id in self._iter_all_frames(stack_frames, skip_idle=self.skip_idle):
-            self._process_frames(frames)
+        if stack_frames and hasattr(stack_frames[0], "awaited_by"):
+            # Async frame processing
+            for frames, thread_id, task_id in self._iter_async_frames(stack_frames):
+                self._process_frames(frames)
+        else:
+            # Regular frame processing
+            for frames, thread_id in self._iter_all_frames(stack_frames, skip_idle=self.skip_idle):
+                self._process_frames(frames)
 
     def export(self, filename):
         self.create_stats()
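
PstatsCollector keeps a single collect() entry point and duck-types its input: samples produced by an async-aware unwind expose an awaited_by attribute on their first element, while regular interpreter samples do not. A small sketch of that detection, using namedtuples as assumed stand-ins for the real unwinder result types:

from collections import namedtuple

# Illustrative stand-ins; only the attribute used for detection matters here.
InterpreterInfo = namedtuple("InterpreterInfo", "threads")
AwaitedInfo = namedtuple("AwaitedInfo", "thread_id awaited_by")

def is_async_sample(stack_frames):
    # Same duck-typing check as PstatsCollector.collect
    return bool(stack_frames) and hasattr(stack_frames[0], "awaited_by")

print(is_async_sample([InterpreterInfo(threads=[])]))              # False -> sync path
print(is_async_sample([AwaitedInfo(thread_id=1, awaited_by=[])]))  # True  -> async path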
