From 015fc6a8435d8773b92f1c3fa9f0993a8c554b4d Mon Sep 17 00:00:00 2001
From: Hector
Date: Mon, 10 Nov 2025 17:51:20 +0000
Subject: [PATCH] Add RunWithSubprocess

---
 requirements-dev.txt            |   1 +
 src/launchpad/kafka.py          | 340 +++++++++++++++++++++-------------
 tests/integration/test_kafka.py | 211 ++++++++++++++++++++
 3 files changed, 420 insertions(+), 132 deletions(-)
 create mode 100644 tests/integration/test_kafka.py

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 89cf9dd6..5bdc3937 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -17,6 +17,7 @@ pytest-xdist>=3.8.0
 pytest-benchmark>=5.1.0
 pytest>=8.4.2
 responses>=0.25.8
+time-machine>=2.15.0
 ruff>=0.12.12
 safety>=3.6.1
 ty==0.0.1a20
diff --git a/src/launchpad/kafka.py b/src/launchpad/kafka.py
index 0077fea7..0079c1ad 100644
--- a/src/launchpad/kafka.py
+++ b/src/launchpad/kafka.py
@@ -5,28 +5,23 @@
 import logging
 import multiprocessing
 import os
-import signal
 import time

 from dataclasses import dataclass
-from functools import partial
 from logging.handlers import QueueHandler, QueueListener
-from multiprocessing.pool import Pool
-from typing import Any, Callable, Mapping
+from multiprocessing.connection import Connection
+from typing import Any, Callable, Generic, Mapping, TypeVar, Union

-from arroyo import Message, Topic, configure_metrics
+from arroyo import Topic, configure_metrics
 from arroyo.backends.kafka import KafkaConsumer as ArroyoKafkaConsumer
 from arroyo.backends.kafka import KafkaPayload
+from arroyo.dlq import InvalidMessage
 from arroyo.processing.processor import StreamProcessor
 from arroyo.processing.strategies import ProcessingStrategy, ProcessingStrategyFactory
+from arroyo.processing.strategies.abstract import MessageRejected
 from arroyo.processing.strategies.commit import CommitOffsets
 from arroyo.processing.strategies.healthcheck import Healthcheck
-from arroyo.processing.strategies.run_task_with_multiprocessing import (
-    MultiprocessingPool,
-    RunTaskWithMultiprocessing,
-    parallel_worker_initializer,
-)
-from arroyo.types import Commit, FilteredPayload, Partition, TStrategyPayload
+from arroyo.types import Commit, FilteredPayload, Message, Partition, TStrategyPayload
 from sentry_kafka_schemas import get_codec

 from launchpad.artifact_processor import ArtifactProcessor
@@ -35,51 +30,206 @@
 from launchpad.utils.arroyo_metrics import DatadogMetricsBackend
 from launchpad.utils.logging import get_logger

+TResult = TypeVar("TResult")
+
 logger = get_logger(__name__)

 # Schema codec for preprod artifact events
 PREPROD_ARTIFACT_SCHEMA = get_codec(PREPROD_ARTIFACT_EVENTS_TOPIC)


-class LaunchpadMultiProcessingPool(MultiprocessingPool):
-    """Extended MultiprocessingPool with maxtasksperchild=1 to ensure clean worker state."""
+def trampoline(function: Callable, log_queue: multiprocessing.Queue, conn: Connection) -> None:
+    """Entry point executed in the child process.
+
+    Replaces the root logger's handlers with a queue handler on
+    ``log_queue``, receives one input from ``conn``, calls ``function``
+    on it and sends back either the result or the raised exception.
+    """
+    root_logger = logging.getLogger()
+    root_logger.handlers.clear()
+    queue_handler = QueueHandler(log_queue)
+    queue_handler.addFilter(RequestLogFilter())
+    root_logger.addHandler(queue_handler)
+    root_logger.setLevel(logging.DEBUG)
+
+    input_message = conn.recv()
+    try:
+        result = function(input_message)
+    except Exception as e:
+        conn.send(e)
+    else:
+        conn.send(result)
+    conn.close()
+

-    def maybe_create_pool(self) -> None:
-        if self._MultiprocessingPool__pool is None:
-            self._MultiprocessingPool__metrics.increment("arroyo.strategies.run_task_with_multiprocessing.pool.create")
-            self._MultiprocessingPool__pool = Pool(
-                self._MultiprocessingPool__num_processes,
-                initializer=partial(parallel_worker_initializer, self._MultiprocessingPool__initializer),
-                context=multiprocessing.get_context("spawn"),
-                maxtasksperchild=1,  # why we have this subclass
-            )
+class Job(Generic[TStrategyPayload, TResult]):
+    """Run ``function`` on one message payload in a dedicated child process."""
+
+    def __init__(
+        self,
+        function: Callable,
+        log_queue: multiprocessing.Queue,
+        message: Message[TStrategyPayload],
+        deadline: float = 0,
+    ) -> None:
+        ctx = multiprocessing.get_context("forkserver")
+        ours, theirs = ctx.Pipe(True)
+        self.__process = ctx.Process(target=trampoline, args=(function, log_queue, theirs))
+        self.__process.start()
+        self.__ours = ours
+        self.__message = message
+        self.__deadline = deadline
+        ours.send(message.payload)
+
+    def poll(self) -> Union[Message[TResult], None]:
+        """Return the finished message, or None if there is nothing to report yet.
+
+        Raises InvalidMessage once the deadline has passed (after killing the
+        child) and re-raises any exception the child sent back.
+        """
+        if not self.__message:
+            return None
+        if self.__deadline and time.time() > self.__deadline:
+            # Kill the child first; otherwise a hung task outlives its timeout.
+            value = self.__message.value
+            self.terminate()
+            raise InvalidMessage.from_value(value)
+        if not self.__ours.poll(0):
+            return None
+        result = self.__ours.recv()
+        self.__ours.close()
+        self.__process.join()
+        self.__process.close()
+        self.__process = None
+
+        message = self.__message
+        self.__message = None
+        if isinstance(result, Exception):
+            raise result
+        else:
+            return message.replace(result)

+    def terminate(self) -> None:
+        """Kill the child process if it is still tracked and forget the message."""
+        if self.__process:
+            self.__process.terminate()
+            self.__process = None
+        self.__message = None

-class LaunchpadRunTaskWithMultiprocessing(RunTaskWithMultiprocessing[TStrategyPayload, Any]):
-    """Tolerates child process exits from maxtasksperchild=1 by ignoring SIGCHLD."""

+class RunTaskWithSubprocess(
+    ProcessingStrategy[Union[FilteredPayload, TStrategyPayload]], Generic[TStrategyPayload, TResult]
+):
+    """Processing strategy that runs each message in its own child process.
+
+    Holds at most one waiting message, one running job and one result awaiting
+    hand-off to the next step; ``submit`` raises MessageRejected when the
+    waiting slot is occupied.
+    """
+
     def __init__(
         self,
-        function: Callable[[Message[TStrategyPayload]], Any],
-        next_step: ProcessingStrategy[FilteredPayload | Any],
-        max_batch_size: int,
-        max_batch_time: float,
-        pool: MultiprocessingPool,
-        input_block_size: int | None = None,
-        output_block_size: int | None = None,
+        function: Callable[[TStrategyPayload], TResult],
+        next_step: ProcessingStrategy[Union[FilteredPayload, TResult]],
+        timeout_s: float = 30.0,
     ) -> None:
-        super().__init__(function, next_step, max_batch_size, max_batch_time, pool, input_block_size, output_block_size)
-        # Override SIGCHLD handler - child exits are expected with maxtasksperchild=1
-        signal.signal(
-            signal.SIGCHLD,
-            lambda signum, frame: logger.debug(f"Worker process exited normally (SIGCHLD {signum})"),
-        )
+        self.__function = function
+        self.__next_step = next_step
+        self.__closed = False
+        self.__timeout = timeout_s
+
+        self.__pending_input = None
+        self.__job = None
+        self.__pending_output = None
+
+        ctx = multiprocessing.get_context("forkserver")
+        self.__log_queue = ctx.Queue()
+        root_logger = logging.getLogger()
+        handlers = list(root_logger.handlers) if root_logger.handlers else []
+        self.__queue_listener = QueueListener(self.__log_queue, *handlers, respect_handler_level=True)
+        self.__queue_listener.start()
+
+    def submit(self, message: Message[Union[FilteredPayload, TStrategyPayload]]) -> None:
+        """Hold one message for processing; raise MessageRejected if closed or a message is already held."""
+        if self.__closed:
+            raise MessageRejected("Strategy is closed")
+
+        if self.__pending_input:
+            raise MessageRejected("Strategy full")
+
+        self.__pending_input = message
+
+    def poll(self) -> None:
+        """Advance by one step (hand off a result, check the job, or start a job), then poll the next step."""
+        if self.__pending_output:
+            try:
+                self.__next_step.submit(self.__pending_output)
+            except MessageRejected:
+                pass
+            else:
+                self.__pending_output = None
+        elif self.__job:
+            assert self.__pending_output is None
+            try:
+                result = self.__job.poll()
+            except BaseException:
+                self.__job = None
+                raise
+            else:
+                if result:
+                    self.__job = None
+                    self.__pending_output = result
+        elif self.__pending_input:
+            assert self.__job is None
+            deadline = time.time() + self.__timeout
+            self.__job = Job(self.__function, self.__log_queue, self.__pending_input, deadline)
+            self.__pending_input = None
+        else:
+            pass
+
+        self.__next_step.poll()
+
+    def close(self) -> None:
+        """Stop accepting new messages."""
+        self.__closed = True
+
+    def terminate(self) -> None:
+        """Stop immediately: kill any running job, drop held messages and terminate the next step."""
+        self.__closed = True
+        self.__queue_listener.stop()
+
+        if self.__job:
+            self.__job.terminate()
+            self.__job = None
+
+        self.__pending_input = None
+        self.__pending_output = None
+
+        self.__next_step.terminate()
+
+    def join(self, timeout: float | None = None) -> None:
+        """Poll until idle or ``timeout`` seconds elapse, then close and join the next step."""
+        timeout = 24 * 60 * 60 if timeout is None else timeout
+        start = time.time()
+        deadline = start + timeout
+
+        while time.time() < deadline:
+            self.poll()
+            if not self.__pending_output and not self.__pending_input and not self.__job:
+                break
+            time.sleep(0)
+
+        remaining = deadline - time.time()

+        self.__queue_listener.stop()

-def process_kafka_message_with_service(msg: Message[KafkaPayload]) -> Any:
+        self.__next_step.close()
+        self.__next_step.join(remaining)
+
+
+def process_kafka_message_with_service(payload: KafkaPayload) -> Any:
     """Process a Kafka message using the actual service logic in a worker process."""
     try:
-        decoded = PREPROD_ARTIFACT_SCHEMA.decode(msg.payload.value)
+        decoded = PREPROD_ARTIFACT_SCHEMA.decode(payload.value)
         ArtifactProcessor.process_message(decoded)
         return decoded  # type: ignore[no-any-return]
     except Exception as e:
@@ -125,12 +275,9 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:
     arroyo_consumer = ArroyoKafkaConsumer(consumer_config)

     healthcheck_path = config.healthcheck_file
+    assert healthcheck_path

-    strategy_factory = LaunchpadStrategyFactory(
-        concurrency=config.concurrency,
-        max_pending_futures=config.max_pending_futures,
-        healthcheck_file=healthcheck_path,
-    )
+    strategy_factory = LaunchpadStrategyFactory(healthcheck_path)

     topics = [Topic(topic) for topic in config.topics]
     topic = topics[0] if topics else Topic("default")
@@ -140,24 +287,17 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:
         processor_factory=strategy_factory,
         join_timeout=config.join_timeout_seconds,  # Drop in-flight work during rebalance before Kafka times out
     )
-    return LaunchpadKafkaConsumer(processor, healthcheck_path, strategy_factory)
+    return LaunchpadKafkaConsumer(processor, healthcheck_path)


 class LaunchpadKafkaConsumer:
     processor: StreamProcessor[KafkaPayload]
     healthcheck_path: str | None
-    strategy_factory: LaunchpadStrategyFactory
     _running: bool

-    def __init__(
-        self,
-        processor: StreamProcessor[KafkaPayload],
-        healthcheck_path: str | None,
-        strategy_factory: LaunchpadStrategyFactory,
-    ):
+    def __init__(self, processor: StreamProcessor[KafkaPayload], healthcheck_path: str):
         self.processor = processor
         self.healthcheck_path = healthcheck_path
-        self.strategy_factory = strategy_factory
         self._running = False

     def run(self):
@@ -169,30 +309,12 @@ def run(self):
             self.processor.run()
         finally:
             self._running = False
-            try:
-                os.remove(self.healthcheck_path)
-                logger.info(f"Removed healthcheck file: {self.healthcheck_path}")
-            except FileNotFoundError:
-                pass
-
-            # Clean up multiprocessing pool
-            try:
-                self.strategy_factory.close()
-                logger.debug("Closed multiprocessing pool")
-            except Exception:
-                logger.exception("Error closing multiprocessing pool")

     def stop(self):
         """Signal shutdown to the processor."""
         logger.info(f"{self} stop commanded")
         self.processor.signal_shutdown()

-        # Kill all multiprocessing worker children (development only)
-        environment = os.getenv("LAUNCHPAD_ENV", "development").lower()
-        if environment == "development":
-            for child in multiprocessing.active_children():
-                child.terminate()
-
     def is_healthy(self) -> bool:
         try:
             mtime = os.path.getmtime(self.healthcheck_path)
@@ -204,78 +326,32 @@ def is_healthy(self) -> bool:


 class LaunchpadStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
     """Factory for creating the processing strategy chain."""

-    def __init__(
-        self,
-        concurrency: int,
-        max_pending_futures: int,
-        healthcheck_file: str | None = None,
-    ) -> None:
-        self._log_queue: multiprocessing.Queue[Any] = multiprocessing.Manager().Queue(-1)
-        self._queue_listener = self._setup_queue_listener()
-        self._queue_listener.start()
-
-        initializer_with_queue = partial(self._initialize_worker_logging, self._log_queue)
-
-        self._pool = LaunchpadMultiProcessingPool(
-            num_processes=concurrency,
-            initializer=initializer_with_queue,
-        )
-        self.concurrency = concurrency
-        self.max_pending_futures = max_pending_futures
-        self.healthcheck_file = healthcheck_file
-
-    def _setup_queue_listener(self) -> QueueListener:
-        """Set up listener in main process to handle logs from workers."""
-        root_logger = logging.getLogger()
-        handlers = list(root_logger.handlers) if root_logger.handlers else []
-
-        return QueueListener(self._log_queue, *handlers, respect_handler_level=True)
-
-    @staticmethod
-    def _initialize_worker_logging(log_queue: multiprocessing.Queue[Any]) -> None:
-        """Initialize logging in worker process to send logs to queue.
-
-        With multiprocessing spawn context, subprocesses don't inherit
-        parent's stdout/stderr. We use a queue to send log records to
-        the main process which writes them to stdout for Docker/GCP.
-        """
-        root_logger = logging.getLogger()
-        root_logger.handlers.clear()
-
-        queue_handler = QueueHandler(log_queue)
-        queue_handler.addFilter(RequestLogFilter())
-
-        root_logger.addHandler(queue_handler)
-        root_logger.setLevel(logging.DEBUG)
+    def __init__(self, healthcheck_path: str) -> None:
+        assert healthcheck_path
+        self.healthcheck_path = healthcheck_path

     def create_with_partitions(
         self,
         commit: Commit,
         partitions: Mapping[Partition, int],
     ) -> ProcessingStrategy[KafkaPayload]:
         """Create the processing strategy chain."""
-        next_step: ProcessingStrategy[Any] = CommitOffsets(commit)
-        assert self.healthcheck_file
-        next_step = Healthcheck(self.healthcheck_file, next_step)
-
-        strategy = LaunchpadRunTaskWithMultiprocessing(
-            process_kafka_message_with_service,
-            next_step=next_step,
-            max_batch_size=1,  # Process immediately, subject to be re-tuned
-            max_batch_time=1,  # Process after 1 second max, subject to be re-tuned
-            pool=self._pool,
-            input_block_size=None,
-            output_block_size=None,
-        )
+        do_commit = CommitOffsets(commit)
+        do_health_check = Healthcheck(self.healthcheck_path, next_step=do_commit)
+        do_task = RunTaskWithSubprocess(process_kafka_message_with_service, next_step=do_health_check, timeout_s=60 * 5)

-        return strategy
+        return do_task

-    def close(self) -> None:
-        """Clean up the multiprocessing pool and logging queue."""
-        self._pool.close()
-        self._queue_listener.stop()
+    def shutdown(self) -> None:
+        """Remove the healthcheck file, logging whether it was present."""
+        try:
+            os.remove(self.healthcheck_path)
+        except FileNotFoundError:
+            logger.error(f"Failed to remove healthcheck file: {self.healthcheck_path}")
+        else:
+            logger.info(f"Removed healthcheck file: {self.healthcheck_path}")


 @dataclass
diff --git a/tests/integration/test_kafka.py b/tests/integration/test_kafka.py
new file mode 100644
index 00000000..d37cbe49
--- /dev/null
+++ b/tests/integration/test_kafka.py
@@ -0,0 +1,211 @@
+import multiprocessing
+import string
+import time
+
+from datetime import datetime
+from unittest.mock import Mock
+
+import pytest
+
+from arroyo.backends.kafka import KafkaPayload
+from arroyo.dlq import InvalidMessage
+from arroyo.processing.strategies.abstract import MessageRejected
+from arroyo.types import BrokerValue, Message, Partition, Topic
+
+from launchpad.kafka import Job, RunTaskWithSubprocess
+
+topic = Topic("topic")
+partition = Partition(topic, 0)
+
+
+def make_value(payload):
+    return BrokerValue(payload, partition, 0, datetime.now())
+
+
+def run_multiply(payload: KafkaPayload) -> KafkaPayload:
+    return KafkaPayload(payload.key, payload.value * 2, payload.headers)
+
+
+def run_uppercase(payload: KafkaPayload) -> KafkaPayload:
+    return KafkaPayload(payload.key, payload.value.upper(), payload.headers)
+
+
+def run_raise(x: KafkaPayload) -> KafkaPayload:
+    raise ValueError("Function failed intentionally")
+
+
+def run_sleep(_: KafkaPayload) -> None:
+    time.sleep(1000)
+
+
+def test_successful_function_execution() -> None:
+    next_step = Mock()
+
+    strategy = RunTaskWithSubprocess(run_multiply, next_step)
+
+    input_payload = KafkaPayload(None, b"hello", [])
+    strategy.submit(Message(make_value(input_payload)))
+
+    poll_count = 0
+    while next_step.submit.call_count == 0:
+        strategy.poll()
+        poll_count += 1
+        time.sleep(0)
+
+    next_step.submit.assert_called_once()
+
+    message = next_step.submit.call_args[0][0]
+
+    assert message.payload.value == b"hellohello"
+    assert message.payload.key is None
+    assert message.payload.headers == []
+
+    assert next_step.poll.call_count == poll_count
+
+    strategy.close()
+    strategy.join()
+
+
+def test_real_timeout() -> None:
+    next_step = Mock()
+
+    strategy = RunTaskWithSubprocess(run_sleep, next_step, timeout_s=0.1)
+
+    input_payload = KafkaPayload(None, b"hello", [])
+    strategy.submit(Message(make_value(input_payload)))
+
+    with pytest.raises(InvalidMessage):
+        while next_step.submit.call_count == 0:
+            strategy.poll()
+            time.sleep(0)
+
+    strategy.close()
+    strategy.join()
+
+
+def test_function_exception_propagation() -> None:
+    next_step = Mock()
+
+    strategy = RunTaskWithSubprocess(
+        run_raise,
+        next_step,
+    )
+
+    input_payload = KafkaPayload(None, b"test", [])
+    strategy.submit(Message(make_value(input_payload)))
+
+    with pytest.raises(ValueError, match=r".*Function failed intentionally.*"):
+        while True:
+            strategy.poll()
+            time.sleep(0)
+
+    next_step.submit.assert_not_called()
+
+    strategy.close()
+    strategy.join()
+
+
+def test_can_be_terminated() -> None:
+    next_step = Mock()
+
+    strategy = RunTaskWithSubprocess(
+        run_raise,
+        next_step,
+    )
+
+    input_payload = KafkaPayload(None, b"test", [])
+    strategy.submit(Message(make_value(input_payload)))
+    strategy.poll()
+    strategy.poll()
+    strategy.terminate()
+
+
+def test_applies_back_pressure() -> None:
+    next_step = Mock()
+
+    a = Message(make_value(KafkaPayload(None, b"a", [])))
+    b = Message(make_value(KafkaPayload(None, b"b", [])))
+
+    strategy = RunTaskWithSubprocess(run_multiply, next_step)
+    strategy.submit(a)
+    with pytest.raises(MessageRejected):
+        strategy.submit(b)
+
+    strategy.close()
+    strategy.join()
+
+
+def test_submit_does_no_work() -> None:
+    next_step = Mock()
+
+    a = Message(make_value(KafkaPayload(None, b"a", [])))
+    b = Message(make_value(KafkaPayload(None, b"b", [])))
+    c = Message(make_value(KafkaPayload(None, b"c", [])))
+
+    strategy = RunTaskWithSubprocess(run_multiply, next_step)
+    strategy.submit(a)
+
+    # This poll() starts the task for a:
+    strategy.poll()
+    # ...freeing a slot to submit b:
+    strategy.submit(b)
+    # ...but submitting again will reject:
+    with pytest.raises(MessageRejected):
+        strategy.submit(c)
+
+    strategy.close()
+    strategy.join()
+
+
+def test_many() -> None:
+    next_step = Mock()
+
+    alphabet = string.ascii_lowercase[:5]
+
+    queue = list(alphabet)[::-1]
+
+    strategy = RunTaskWithSubprocess(run_uppercase, next_step)
+
+    poll_count = 0
+    while next_step.submit.call_count != len(alphabet):
+        if queue:
+            letter = queue[-1]
+            message = Message(make_value(KafkaPayload(None, letter, [])))
+            try:
+                strategy.submit(message)
+            except MessageRejected:
+                pass
+            else:
+                queue.pop()
+
+        poll_count += 1
+        strategy.poll()
+        time.sleep(0.01)
+
+    assert next_step.poll.call_count == poll_count
+    actual = "".join(args[0][0].payload.value for args in next_step.submit.call_args_list)
+
+    assert actual == alphabet.upper()
+
+    strategy.close()
+    strategy.join()
+
+
+def return_hello_world(payload: KafkaPayload):
+    return KafkaPayload(payload.key, "Hello, world!", payload.headers)
+
+
+def test_benchmark_job(benchmark):
+    log_queue = multiprocessing.Queue()
+
+    def do_job():
+        job = Job(return_hello_world, log_queue, Message(make_value(KafkaPayload(None, None, [])))) 
+        while True:
+            r = job.poll()
+            if r is not None:
+                return r
+            time.sleep(0.01)
+
+    final = benchmark(do_job)
+
+    assert final.payload.value == "Hello, world!"