Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions agentlightning/execution/client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ async def _execute_algorithm(
logger.debug("Algorithm bundle starting against endpoint %s", wrapper_store.endpoint)
await algorithm(wrapper_store, stop_evt)
logger.debug("Algorithm bundle completed successfully")
except asyncio.CancelledError:
logger.debug("Algorithm received CancelledError; signaling stop event")
stop_evt.set()
raise
except KeyboardInterrupt:
logger.warning("Algorithm received KeyboardInterrupt; signaling stop event")
stop_evt.set()
Expand Down Expand Up @@ -179,6 +183,10 @@ async def _execute_runner(
logger.debug("Runner %s executing with provided store", worker_id)
await runner(client_store, worker_id, stop_evt)
logger.debug("Runner %s completed successfully", worker_id)
except asyncio.CancelledError:
logger.debug("Runner %s received CancelledError; signaling stop event", worker_id)
stop_evt.set()
raise
except KeyboardInterrupt:
logger.warning("Runner %s received KeyboardInterrupt; signaling stop event", worker_id)
stop_evt.set()
Expand Down Expand Up @@ -210,7 +218,13 @@ def _spawn_runners(
def _runner_sync(runner: RunnerBundle, worker_id: int, store: LightningStore, stop_evt: ExecutionEvent) -> None:
# Runners are executed in child processes; each process owns its own
# event loop to keep the asyncio scheduler isolated.
asyncio.run(self._execute_runner(runner, worker_id, store, stop_evt))
try:
asyncio.run(self._execute_runner(runner, worker_id, store, stop_evt))
except KeyboardInterrupt:
logger.warning("Runner (asyncio) %s received KeyboardInterrupt; exiting gracefully", worker_id)
except BaseException as exc:
logger.exception("Runner (asyncio) %s crashed by %s; signaling stop event", worker_id, exc)
raise

for i in range(self.n_runners):
process = cast(
Expand All @@ -234,7 +248,13 @@ def _spawn_algorithm_process(
"""Used when `main_process == "runner"`."""

def _algorithm_sync(algorithm: AlgorithmBundle, store: LightningStore, stop_evt: ExecutionEvent) -> None:
asyncio.run(self._execute_algorithm(algorithm, store, stop_evt))
try:
asyncio.run(self._execute_algorithm(algorithm, store, stop_evt))
except KeyboardInterrupt:
logger.warning("Algorithm (asyncio.run) received KeyboardInterrupt; exiting gracefully")
except BaseException as exc:
logger.exception("Algorithm (asyncio.run) crashed by %s; signaling stop event", exc)
raise

process = cast(
multiprocessing.Process,
Expand Down
66 changes: 66 additions & 0 deletions tests/execution/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,11 @@ async def _kbint_in_runner(store: LightningStore, worker_id: int, event: Executi
raise KeyboardInterrupt()


async def _cancel_in_runner(store: LightningStore, worker_id: int, event: ExecutionEvent) -> None:
_ = (store, worker_id, event)
raise asyncio.CancelledError()


async def _timeout_error_in_runner(store: LightningStore, worker_id: int, event: ExecutionEvent) -> None:
# Provoke client's validation (pre-request), then raise TimeoutError.
with pytest.raises(ValueError):
Expand Down Expand Up @@ -1236,3 +1241,64 @@ async def algo(store: LightningStore, event: ExecutionEvent) -> None:
assert (
len(store.calls) == initial_call_count
), "Store state should not be modified in main process when main_process='runner'"


def test_spawn_runners_handles_keyboard_interrupt_gracefully(store: LightningStore) -> None:
"""
Test that KeyboardInterrupt (Ctrl+C) is caught by _runner_sync
and results in a graceful exit (exitcode 0).
"""
strat = ClientServerExecutionStrategy(
role="runner",
n_runners=1,
server_host="127.0.0.1",
server_port=_free_port(),
)
ctx = get_context()
stop_evt: ExecutionEvent = MpEvent()

processes = strat._spawn_runners(_kbint_in_runner, store, stop_evt, ctx=ctx) # pyright: ignore[reportPrivateUsage]

try:
for p in processes:
p.join(timeout=2.0)

for p in processes:
assert not p.is_alive()
assert p.exitcode == 0, f"Runner {p.name} should exit gracefully on KeyboardInterrupt"

finally:
for p in processes:
if p.is_alive():
p.terminate()
p.join()


def test_spawn_runners_treats_cancelled_error_as_crash(store: LightningStore) -> None:
"""
Test that asyncio.CancelledError in __spawn_runners causes a crash (exitcode != 0).
"""
strat = ClientServerExecutionStrategy(
role="runner",
n_runners=1,
server_host="127.0.0.1",
server_port=_free_port(),
)
ctx = get_context()
stop_evt: ExecutionEvent = MpEvent()

processes = strat._spawn_runners(_cancel_in_runner, store, stop_evt, ctx=ctx) # pyright: ignore[reportPrivateUsage]

try:
for p in processes:
p.join(timeout=2.0)

for p in processes:
assert not p.is_alive()
assert p.exitcode != 0, f"Runner {p.name} should crash on CancelledError (exitcode={p.exitcode})"

finally:
for p in processes:
if p.is_alive():
p.terminate()
p.join()