Commit 23cb118

Add RunWithSubprocess
1 parent 16fc355

3 files changed: +363 −137 lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ pytest-xdist>=3.8.0
 pytest-benchmark>=5.1.0
 pytest>=8.4.2
 responses>=0.25.8
+time-machine>=2.15.0
 ruff>=0.12.12
 safety>=3.6.1
 ty==0.0.1a20
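
time-machine is a test-clock library; presumably it was added to exercise the new subprocess deadline logic without real sleeps (no test file is visible in this excerpt, so the snippet below is an illustrative sketch, not code from the commit):

    import time

    import time_machine

    # Freeze the clock, then jump past a 30-second deadline in one step,
    # mirroring how Job.poll() compares time.time() against its deadline.
    with time_machine.travel(0, tick=False) as traveller:
        deadline = time.time() + 30.0
        assert time.time() < deadline  # job still within its budget
        traveller.shift(31)            # fast-forward 31 seconds instantly
        assert time.time() > deadline  # Job.poll() would now raise InvalidMessage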

src/launchpad/kafka.py

Lines changed: 155 additions & 137 deletions

@@ -2,78 +2,174 @@
 
 from __future__ import annotations
 
-import logging
 import multiprocessing
 import os
-import signal
 import time
 
 from dataclasses import dataclass
-from functools import partial
-from logging.handlers import QueueHandler, QueueListener
-from multiprocessing.pool import Pool
-from typing import Any, Callable, Mapping
+from multiprocessing.connection import Connection
+from typing import Any, Callable, Generic, Mapping, TypeVar, Union
 
-from arroyo import Message, Topic, configure_metrics
+from arroyo import Topic, configure_metrics
 from arroyo.backends.kafka import KafkaConsumer as ArroyoKafkaConsumer
 from arroyo.backends.kafka import KafkaPayload
+from arroyo.dlq import InvalidMessage
 from arroyo.processing.processor import StreamProcessor
 from arroyo.processing.strategies import ProcessingStrategy, ProcessingStrategyFactory
+from arroyo.processing.strategies.abstract import MessageRejected
 from arroyo.processing.strategies.commit import CommitOffsets
 from arroyo.processing.strategies.healthcheck import Healthcheck
-from arroyo.processing.strategies.run_task_with_multiprocessing import (
-    MultiprocessingPool,
-    RunTaskWithMultiprocessing,
-    parallel_worker_initializer,
-)
-from arroyo.types import Commit, FilteredPayload, Partition, TStrategyPayload
+from arroyo.types import Commit, FilteredPayload, Message, Partition, TStrategyPayload
 from sentry_kafka_schemas import get_codec
 
 from launchpad.artifact_processor import ArtifactProcessor
 from launchpad.constants import HEALTHCHECK_MAX_AGE_SECONDS, PREPROD_ARTIFACT_EVENTS_TOPIC
-from launchpad.tracing import RequestLogFilter
 from launchpad.utils.arroyo_metrics import DatadogMetricsBackend
 from launchpad.utils.logging import get_logger
 
+TResult = TypeVar("TResult")
+
 logger = get_logger(__name__)
 
 # Schema codec for preprod artifact events
 PREPROD_ARTIFACT_SCHEMA = get_codec(PREPROD_ARTIFACT_EVENTS_TOPIC)
 
 
-class LaunchpadMultiProcessingPool(MultiprocessingPool):
-    """Extended MultiprocessingPool with maxtasksperchild=1 to ensure clean worker state."""
-
-    def maybe_create_pool(self) -> None:
-        if self._MultiprocessingPool__pool is None:
-            self._MultiprocessingPool__metrics.increment("arroyo.strategies.run_task_with_multiprocessing.pool.create")
-            self._MultiprocessingPool__pool = Pool(
-                self._MultiprocessingPool__num_processes,
-                initializer=partial(parallel_worker_initializer, self._MultiprocessingPool__initializer),
-                context=multiprocessing.get_context("spawn"),
-                maxtasksperchild=1,  # why we have this subclass
-            )
+def trampoline(function: Callable, conn: Connection) -> None:
+    input_message = conn.recv()
+    try:
+        result = function(input_message)
+    except Exception as e:
+        conn.send(e)
+    else:
+        conn.send(result)
+    conn.close()
+
+
+class Job(Generic[TStrategyPayload, TResult]):
+    def __init__(self, function: Callable, message: Message[TStrategyPayload], deadline: float = 0) -> None:
+        ctx = multiprocessing.get_context("forkserver")
+        ours, theirs = ctx.Pipe(True)
+        self.__process = ctx.Process(target=trampoline, args=(function, theirs))
+        self.__process.start()
+        self.__ours = ours
+        self.__message = message
+        self.__deadline = deadline
+        ours.send(message.payload)
+
+    def poll(self) -> Union[Message[TResult], None]:
+        if not self.__message:
+            return None
+        if self.__deadline and time.time() > self.__deadline:
+            raise InvalidMessage.from_value(self.__message.value)
+        if not self.__ours.poll(0):
+            return None
+        result = self.__ours.recv()
+        self.__ours.close()
+        self.__process.join()
+        self.__process.close()
+        self.__process = None
+
+        message = self.__message
+        self.__message = None
+        if isinstance(result, Exception):
+            raise result
+        else:
+            return message.replace(result)
 
+    def terminate(self) -> None:
+        if self.__process:
+            self.__process.terminate()
+            self.__process = None
+            self.__message = None
 
-class LaunchpadRunTaskWithMultiprocessing(RunTaskWithMultiprocessing[TStrategyPayload, Any]):
-    """Tolerates child process exits from maxtasksperchild=1 by ignoring SIGCHLD."""
 
+class RunTaskWithSubprocess(
+    ProcessingStrategy[Union[FilteredPayload, TStrategyPayload]], Generic[TStrategyPayload, TResult]
+):
     def __init__(
         self,
-        function: Callable[[Message[TStrategyPayload]], Any],
-        next_step: ProcessingStrategy[FilteredPayload | Any],
-        max_batch_size: int,
-        max_batch_time: float,
-        pool: MultiprocessingPool,
-        input_block_size: int | None = None,
-        output_block_size: int | None = None,
+        function: Callable[[Message[TStrategyPayload]], TResult],
+        next_step: ProcessingStrategy[Union[FilteredPayload, TResult]],
+        timeout_s: float = 30.0,
     ) -> None:
-        super().__init__(function, next_step, max_batch_size, max_batch_time, pool, input_block_size, output_block_size)
-        # Override SIGCHLD handler - child exits are expected with maxtasksperchild=1
-        signal.signal(
-            signal.SIGCHLD,
-            lambda signum, frame: logger.debug(f"Worker process exited normally (SIGCHLD {signum})"),
-        )
+        self.__function = function
+        self.__next_step = next_step
+        self.__closed = False
+        self.__timeout = timeout_s
+
+        self.__pending_input = None
+        self.__job = None
+        self.__pending_output = None
+
+    def submit(self, message: Message[Union[FilteredPayload, TStrategyPayload]]) -> None:
+        if self.__closed:
+            raise MessageRejected("Strategy is closed")
+
+        if self.__pending_input:
+            raise MessageRejected("Strategy full")
+
+        self.__pending_input = message
+
+    def poll(self) -> None:
+        if self.__pending_output:
+            try:
+                self.__next_step.submit(self.__pending_output)
+            except MessageRejected:
+                pass
+            else:
+                self.__pending_output = None
+        elif self.__job:
+            assert self.__pending_output is None
+            try:
+                result = self.__job.poll()
+            except:
+                self.__job = None
+                raise
+            else:
+                if result:
+                    self.__job = None
+                    self.__pending_output = result
+        elif self.__pending_input:
+            assert self.__job is None
+            deadline = time.time() + self.__timeout
+            self.__job = Job(self.__function, self.__pending_input, deadline)
+            self.__pending_input = None
+        else:
+            pass
+
+        self.__next_step.poll()
+
+    def close(self) -> None:
+        self.__closed = True
+
+    def terminate(self) -> None:
+        self.__closed = True
+
+        if self.__job:
+            self.__job.terminate()
+            self.__job = None
+
+        self.__pending_input = None
+        self.__pending_output = None
+
+        self.__next_step.terminate()
+
+    def join(self, timeout: float | None = None) -> None:
+        timeout = 24 * 60 * 60 if timeout is None else timeout
+        start = time.time()
+        deadline = start + timeout
+
+        while time.time() < deadline:
+            self.poll()
+            if not self.__pending_output and not self.__pending_input and not self.__job:
+                break
+            time.sleep(0)
+
+        remaining = deadline - time.time()
+
+        self.__next_step.close()
+        self.__next_step.join(remaining)
 
 
 def process_kafka_message_with_service(msg: Message[KafkaPayload]) -> Any:
@@ -125,12 +221,9 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:
 
     arroyo_consumer = ArroyoKafkaConsumer(consumer_config)
     healthcheck_path = config.healthcheck_file
+    assert healthcheck_path
 
-    strategy_factory = LaunchpadStrategyFactory(
-        concurrency=config.concurrency,
-        max_pending_futures=config.max_pending_futures,
-        healthcheck_file=healthcheck_path,
-    )
+    strategy_factory = LaunchpadStrategyFactory(healthcheck_path)
 
     topics = [Topic(topic) for topic in config.topics]
     topic = topics[0] if topics else Topic("default")
@@ -140,24 +233,17 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:
         processor_factory=strategy_factory,
         join_timeout=config.join_timeout_seconds,  # Drop in-flight work during rebalance before Kafka times out
     )
-    return LaunchpadKafkaConsumer(processor, healthcheck_path, strategy_factory)
+    return LaunchpadKafkaConsumer(processor, healthcheck_path)
 
 
 class LaunchpadKafkaConsumer:
     processor: StreamProcessor[KafkaPayload]
     healthcheck_path: str | None
-    strategy_factory: LaunchpadStrategyFactory
     _running: bool
 
-    def __init__(
-        self,
-        processor: StreamProcessor[KafkaPayload],
-        healthcheck_path: str | None,
-        strategy_factory: LaunchpadStrategyFactory,
-    ):
+    def __init__(self, processor: StreamProcessor[KafkaPayload], healthcheck_path: str):
         self.processor = processor
         self.healthcheck_path = healthcheck_path
-        self.strategy_factory = strategy_factory
         self._running = False
 
     def run(self):
@@ -169,30 +255,12 @@ def run(self):
             self.processor.run()
         finally:
             self._running = False
-            try:
-                os.remove(self.healthcheck_path)
-                logger.info(f"Removed healthcheck file: {self.healthcheck_path}")
-            except FileNotFoundError:
-                pass
-
-            # Clean up multiprocessing pool
-            try:
-                self.strategy_factory.close()
-                logger.debug("Closed multiprocessing pool")
-            except Exception:
-                logger.exception("Error closing multiprocessing pool")
 
     def stop(self):
         """Signal shutdown to the processor."""
         logger.info(f"{self} stop commanded")
         self.processor.signal_shutdown()
 
-        # Kill all multiprocessing worker children (development only)
-        environment = os.getenv("LAUNCHPAD_ENV", "development").lower()
-        if environment == "development":
-            for child in multiprocessing.active_children():
-                child.terminate()
-
     def is_healthy(self) -> bool:
         try:
             mtime = os.path.getmtime(self.healthcheck_path)
@@ -204,78 +272,28 @@ def is_healthy(self) -> bool:
 
 
 class LaunchpadStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
-    """Factory for creating the processing strategy chain."""
-
-    def __init__(
-        self,
-        concurrency: int,
-        max_pending_futures: int,
-        healthcheck_file: str | None = None,
-    ) -> None:
-        self._log_queue: multiprocessing.Queue[Any] = multiprocessing.Manager().Queue(-1)
-        self._queue_listener = self._setup_queue_listener()
-        self._queue_listener.start()
-
-        initializer_with_queue = partial(self._initialize_worker_logging, self._log_queue)
-
-        self._pool = LaunchpadMultiProcessingPool(
-            num_processes=concurrency,
-            initializer=initializer_with_queue,
-        )
-        self.concurrency = concurrency
-        self.max_pending_futures = max_pending_futures
-        self.healthcheck_file = healthcheck_file
-
-    def _setup_queue_listener(self) -> QueueListener:
-        """Set up listener in main process to handle logs from workers."""
-        root_logger = logging.getLogger()
-        handlers = list(root_logger.handlers) if root_logger.handlers else []
-
-        return QueueListener(self._log_queue, *handlers, respect_handler_level=True)
-
-    @staticmethod
-    def _initialize_worker_logging(log_queue: multiprocessing.Queue[Any]) -> None:
-        """Initialize logging in worker process to send logs to queue.
-
-        With multiprocessing spawn context, subprocesses don't inherit
-        parent's stdout/stderr. We use a queue to send log records to
-        the main process which writes them to stdout for Docker/GCP.
-        """
-        root_logger = logging.getLogger()
-        root_logger.handlers.clear()
-
-        queue_handler = QueueHandler(log_queue)
-        queue_handler.addFilter(RequestLogFilter())
-
-        root_logger.addHandler(queue_handler)
-        root_logger.setLevel(logging.DEBUG)
+    def __init__(self, healthcheck_path: str) -> None:
+        assert healthcheck_path
+        self.healthcheck_path = healthcheck_path
 
     def create_with_partitions(
         self,
         commit: Commit,
         partitions: Mapping[Partition, int],
     ) -> ProcessingStrategy[KafkaPayload]:
-        """Create the processing strategy chain."""
-        next_step: ProcessingStrategy[Any] = CommitOffsets(commit)
-        assert self.healthcheck_file
-        next_step = Healthcheck(self.healthcheck_file, next_step)
-
-        strategy = LaunchpadRunTaskWithMultiprocessing(
-            process_kafka_message_with_service,
-            next_step=next_step,
-            max_batch_size=1,  # Process immediately, subject to be re-tuned
-            max_batch_time=1,  # Process after 1 second max, subject to be re-tuned
-            pool=self._pool,
-            input_block_size=None,
-            output_block_size=None,
-        )
+        do_commit = CommitOffsets(commit)
+        do_health_check = Healthcheck(self.healthcheck_path, next_step=do_commit)
+        do_task = RunTaskWithSubprocess(process_kafka_message_with_service, next_step=do_health_check)
 
-        return strategy
+        return do_task
 
-    def close(self) -> None:
-        """Clean up the multiprocessing pool and logging queue."""
-        self._pool.close()
-        self._queue_listener.stop()
+    def shutdown(self) -> None:
+        try:
+            os.remove(self.healthcheck_path)
+        except FileNotFoundError:
+            logger.error(f"Failed to remove healthcheck file: {self.healthcheck_path}")
+        else:
+            logger.info(f"Removed healthcheck file: {self.healthcheck_path}")
 
 
 @dataclass
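
The heart of the change is the Job/trampoline pair above: each message gets a fresh forkserver process fed over a Pipe, the parent polls the pipe without blocking, and a job that outlives its deadline is surfaced as an InvalidMessage for the DLQ. A minimal standalone sketch of the same pattern (the square task and 5-second budget are illustrative stand-ins, not from this commit; note forkserver is unavailable on Windows):

    import multiprocessing
    import time
    from multiprocessing.connection import Connection
    from typing import Any, Callable


    def trampoline(function: Callable[[Any], Any], conn: Connection) -> None:
        # Child side: receive one input, send back the result or the exception.
        payload = conn.recv()
        try:
            result = function(payload)
        except Exception as e:
            conn.send(e)
        else:
            conn.send(result)
        conn.close()


    def square(x: int) -> int:  # stand-in for the real message-processing function
        return x * x


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("forkserver")
        ours, theirs = ctx.Pipe(True)
        proc = ctx.Process(target=trampoline, args=(square, theirs))
        proc.start()
        ours.send(7)

        deadline = time.time() + 5.0
        while not ours.poll(0):         # non-blocking check, as in Job.poll()
            if time.time() > deadline:  # where the strategy would raise InvalidMessage
                proc.terminate()
                raise TimeoutError("worker exceeded its deadline")
            time.sleep(0.01)

        result = ours.recv()
        proc.join()
        print(result)  # prints 49

Judging by the removed code, the old pool already forced one process per task via maxtasksperchild=1 (plus a SIGCHLD workaround for the resulting child exits); this commit replaces that machinery with a strategy that simply owns exactly one short-lived child at a time and applies backpressure through MessageRejected while it is busy.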
