22
33from __future__ import annotations
44
5- import logging
65import multiprocessing
76import os
8- import signal
97import time
108
119from dataclasses import dataclass
12- from functools import partial
13- from logging .handlers import QueueHandler , QueueListener
14- from multiprocessing .pool import Pool
15- from typing import Any , Callable , Mapping
10+ from multiprocessing .connection import Connection
11+ from typing import Any , Callable , Generic , Mapping , TypeVar , Union
1612
17- from arroyo import Message , Topic , configure_metrics
13+ from arroyo import Topic , configure_metrics
1814from arroyo .backends .kafka import KafkaConsumer as ArroyoKafkaConsumer
1915from arroyo .backends .kafka import KafkaPayload
16+ from arroyo .dlq import InvalidMessage
2017from arroyo .processing .processor import StreamProcessor
2118from arroyo .processing .strategies import ProcessingStrategy , ProcessingStrategyFactory
19+ from arroyo .processing .strategies .abstract import MessageRejected
2220from arroyo .processing .strategies .commit import CommitOffsets
2321from arroyo .processing .strategies .healthcheck import Healthcheck
24- from arroyo .processing .strategies .run_task_with_multiprocessing import (
25- MultiprocessingPool ,
26- RunTaskWithMultiprocessing ,
27- parallel_worker_initializer ,
28- )
29- from arroyo .types import Commit , FilteredPayload , Partition , TStrategyPayload
22+ from arroyo .types import Commit , FilteredPayload , Message , Partition , TStrategyPayload
3023from sentry_kafka_schemas import get_codec
3124
3225from launchpad .artifact_processor import ArtifactProcessor
3326from launchpad .constants import HEALTHCHECK_MAX_AGE_SECONDS , PREPROD_ARTIFACT_EVENTS_TOPIC
34- from launchpad .tracing import RequestLogFilter
3527from launchpad .utils .arroyo_metrics import DatadogMetricsBackend
3628from launchpad .utils .logging import get_logger
3729
30+ TResult = TypeVar ("TResult" )
31+
3832logger = get_logger (__name__ )
3933
4034# Schema codec for preprod artifact events
4135PREPROD_ARTIFACT_SCHEMA = get_codec (PREPROD_ARTIFACT_EVENTS_TOPIC )
4236
4337
44- class LaunchpadMultiProcessingPool (MultiprocessingPool ):
45- """Extended MultiprocessingPool with maxtasksperchild=1 to ensure clean worker state."""
46-
47- def maybe_create_pool (self ) -> None :
48- if self ._MultiprocessingPool__pool is None :
49- self ._MultiprocessingPool__metrics .increment ("arroyo.strategies.run_task_with_multiprocessing.pool.create" )
50- self ._MultiprocessingPool__pool = Pool (
51- self ._MultiprocessingPool__num_processes ,
52- initializer = partial (parallel_worker_initializer , self ._MultiprocessingPool__initializer ),
53- context = multiprocessing .get_context ("spawn" ),
54- maxtasksperchild = 1 , # why we have this subclass
55- )
38+ def trampoline (function : Callable , conn : Connection ) -> None :
39+ input_message = conn .recv ()
40+ try :
41+ result = function (input_message )
42+ except Exception as e :
43+ conn .send (e )
44+ else :
45+ conn .send (result )
46+ conn .close ()
47+
48+
49+ class Job (Generic [TStrategyPayload , TResult ]):
50+ def __init__ (self , function : Callable , message : Message [TStrategyPayload ], deadline : float = 0 ) -> None :
51+ ctx = multiprocessing .get_context ("forkserver" )
52+ ours , theirs = ctx .Pipe (True )
53+ self .__process = ctx .Process (target = trampoline , args = (function , theirs ))
54+ self .__process .start ()
55+ self .__ours = ours
56+ self .__message = message
57+ self .__deadline = deadline
58+ ours .send (message .payload )
59+
60+ def poll (self ) -> Union [Message [TResult ], None ]:
61+ if not self .__message :
62+ return None
63+ if self .__deadline and time .time () > self .__deadline :
64+ raise InvalidMessage .from_value (self .__message .value )
65+ if not self .__ours .poll (0 ):
66+ return None
67+ result = self .__ours .recv ()
68+ self .__ours .close ()
69+ self .__process .join ()
70+ self .__process .close ()
71+ self .__process = None
72+
73+ message = self .__message
74+ self .__message = None
75+ if isinstance (result , Exception ):
76+ raise result
77+ else :
78+ return message .replace (result )
5679
80+ def terminate (self ) -> None :
81+ if self .__process :
82+ self .__process .terminate ()
83+ self .__process = None
84+ self .__message = None
5785
58- class LaunchpadRunTaskWithMultiprocessing (RunTaskWithMultiprocessing [TStrategyPayload , Any ]):
59- """Tolerates child process exits from maxtasksperchild=1 by ignoring SIGCHLD."""
6086
87+ class RunTaskWithSubprocess (
88+ ProcessingStrategy [Union [FilteredPayload , TStrategyPayload ]], Generic [TStrategyPayload , TResult ]
89+ ):
6190 def __init__ (
6291 self ,
63- function : Callable [[Message [TStrategyPayload ]], Any ],
64- next_step : ProcessingStrategy [FilteredPayload | Any ],
65- max_batch_size : int ,
66- max_batch_time : float ,
67- pool : MultiprocessingPool ,
68- input_block_size : int | None = None ,
69- output_block_size : int | None = None ,
92+ function : Callable [[Message [TStrategyPayload ]], TResult ],
93+ next_step : ProcessingStrategy [Union [FilteredPayload , TResult ]],
94+ timeout_s : float = 30.0 ,
7095 ) -> None :
71- super ().__init__ (function , next_step , max_batch_size , max_batch_time , pool , input_block_size , output_block_size )
72- # Override SIGCHLD handler - child exits are expected with maxtasksperchild=1
73- signal .signal (
74- signal .SIGCHLD ,
75- lambda signum , frame : logger .debug (f"Worker process exited normally (SIGCHLD { signum } )" ),
76- )
96+ self .__function = function
97+ self .__next_step = next_step
98+ self .__closed = False
99+ self .__timeout = timeout_s
100+
101+ self .__pending_input = None
102+ self .__job = None
103+ self .__pending_output = None
104+
105+ def submit (self , message : Message [Union [FilteredPayload , TStrategyPayload ]]) -> None :
106+ if self .__closed :
107+ raise MessageRejected ("Strategy is closed" )
108+
109+ if self .__pending_input :
110+ raise MessageRejected ("Strategy full" )
111+
112+ self .__pending_input = message
113+
114+ def poll (self ) -> None :
115+ if self .__pending_output :
116+ try :
117+ self .__next_step .submit (self .__pending_output )
118+ except MessageRejected :
119+ pass
120+ else :
121+ self .__pending_output = None
122+ elif self .__job :
123+ assert self .__pending_output is None
124+ try :
125+ result = self .__job .poll ()
126+ except :
127+ self .__job = None
128+ raise
129+ else :
130+ if result :
131+ self .__job = None
132+ self .__pending_output = result
133+ elif self .__pending_input :
134+ assert self .__job is None
135+ deadline = time .time () + self .__timeout
136+ self .__job = Job (self .__function , self .__pending_input , deadline )
137+ self .__pending_input = None
138+ else :
139+ pass
140+
141+ self .__next_step .poll ()
142+
143+ def close (self ) -> None :
144+ self .__closed = True
145+
146+ def terminate (self ) -> None :
147+ self .__closed = True
148+
149+ if self .__job :
150+ self .__job .terminate ()
151+ self .__job = None
152+
153+ self .__pending_input = None
154+ self .__pending_ouput = None
155+
156+ self .__next_step .terminate ()
157+
158+ def join (self , timeout : float | None = None ) -> None :
159+ timeout = 24 * 60 * 60 if timeout is None else timeout
160+ start = time .time ()
161+ deadline = start + timeout
162+
163+ while time .time () < deadline :
164+ self .poll ()
165+ if not self .__pending_output and not self .__pending_input and not self .__job :
166+ break
167+ time .sleep (0 )
168+
169+ remaining = deadline - time .time ()
170+
171+ self .__next_step .close ()
172+ self .__next_step .join (remaining )
77173
78174
79175def process_kafka_message_with_service (msg : Message [KafkaPayload ]) -> Any :
@@ -125,12 +221,9 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:
125221
126222 arroyo_consumer = ArroyoKafkaConsumer (consumer_config )
127223 healthcheck_path = config .healthcheck_file
224+ assert healthcheck_path
128225
129- strategy_factory = LaunchpadStrategyFactory (
130- concurrency = config .concurrency ,
131- max_pending_futures = config .max_pending_futures ,
132- healthcheck_file = healthcheck_path ,
133- )
226+ strategy_factory = LaunchpadStrategyFactory (healthcheck_path )
134227
135228 topics = [Topic (topic ) for topic in config .topics ]
136229 topic = topics [0 ] if topics else Topic ("default" )
@@ -140,24 +233,17 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:
140233 processor_factory = strategy_factory ,
141234 join_timeout = config .join_timeout_seconds , # Drop in-flight work during rebalance before Kafka times out
142235 )
143- return LaunchpadKafkaConsumer (processor , healthcheck_path , strategy_factory )
236+ return LaunchpadKafkaConsumer (processor , healthcheck_path )
144237
145238
146239class LaunchpadKafkaConsumer :
147240 processor : StreamProcessor [KafkaPayload ]
148241 healthcheck_path : str | None
149- strategy_factory : LaunchpadStrategyFactory
150242 _running : bool
151243
152- def __init__ (
153- self ,
154- processor : StreamProcessor [KafkaPayload ],
155- healthcheck_path : str | None ,
156- strategy_factory : LaunchpadStrategyFactory ,
157- ):
244+ def __init__ (self , processor : StreamProcessor [KafkaPayload ], healthcheck_path : str ):
158245 self .processor = processor
159246 self .healthcheck_path = healthcheck_path
160- self .strategy_factory = strategy_factory
161247 self ._running = False
162248
163249 def run (self ):
@@ -169,30 +255,12 @@ def run(self):
169255 self .processor .run ()
170256 finally :
171257 self ._running = False
172- try :
173- os .remove (self .healthcheck_path )
174- logger .info (f"Removed healthcheck file: { self .healthcheck_path } " )
175- except FileNotFoundError :
176- pass
177-
178- # Clean up multiprocessing pool
179- try :
180- self .strategy_factory .close ()
181- logger .debug ("Closed multiprocessing pool" )
182- except Exception :
183- logger .exception ("Error closing multiprocessing pool" )
184258
185259 def stop (self ):
186260 """Signal shutdown to the processor."""
187261 logger .info (f"{ self } stop commanded" )
188262 self .processor .signal_shutdown ()
189263
190- # Kill all multiprocessing worker children (development only)
191- environment = os .getenv ("LAUNCHPAD_ENV" , "development" ).lower ()
192- if environment == "development" :
193- for child in multiprocessing .active_children ():
194- child .terminate ()
195-
196264 def is_healthy (self ) -> bool :
197265 try :
198266 mtime = os .path .getmtime (self .healthcheck_path )
@@ -204,78 +272,28 @@ def is_healthy(self) -> bool:
204272
205273
206274class LaunchpadStrategyFactory (ProcessingStrategyFactory [KafkaPayload ]):
207- """Factory for creating the processing strategy chain."""
208-
209- def __init__ (
210- self ,
211- concurrency : int ,
212- max_pending_futures : int ,
213- healthcheck_file : str | None = None ,
214- ) -> None :
215- self ._log_queue : multiprocessing .Queue [Any ] = multiprocessing .Manager ().Queue (- 1 )
216- self ._queue_listener = self ._setup_queue_listener ()
217- self ._queue_listener .start ()
218-
219- initializer_with_queue = partial (self ._initialize_worker_logging , self ._log_queue )
220-
221- self ._pool = LaunchpadMultiProcessingPool (
222- num_processes = concurrency ,
223- initializer = initializer_with_queue ,
224- )
225- self .concurrency = concurrency
226- self .max_pending_futures = max_pending_futures
227- self .healthcheck_file = healthcheck_file
228-
229- def _setup_queue_listener (self ) -> QueueListener :
230- """Set up listener in main process to handle logs from workers."""
231- root_logger = logging .getLogger ()
232- handlers = list (root_logger .handlers ) if root_logger .handlers else []
233-
234- return QueueListener (self ._log_queue , * handlers , respect_handler_level = True )
235-
236- @staticmethod
237- def _initialize_worker_logging (log_queue : multiprocessing .Queue [Any ]) -> None :
238- """Initialize logging in worker process to send logs to queue.
239-
240- With multiprocessing spawn context, subprocesses don't inherit
241- parent's stdout/stderr. We use a queue to send log records to
242- the main process which writes them to stdout for Docker/GCP.
243- """
244- root_logger = logging .getLogger ()
245- root_logger .handlers .clear ()
246-
247- queue_handler = QueueHandler (log_queue )
248- queue_handler .addFilter (RequestLogFilter ())
249-
250- root_logger .addHandler (queue_handler )
251- root_logger .setLevel (logging .DEBUG )
275+ def __init__ (self , healthcheck_path : str ) -> None :
276+ assert healthcheck_path
277+ self .healthcheck_path = healthcheck_path
252278
253279 def create_with_partitions (
254280 self ,
255281 commit : Commit ,
256282 partitions : Mapping [Partition , int ],
257283 ) -> ProcessingStrategy [KafkaPayload ]:
258- """Create the processing strategy chain."""
259- next_step : ProcessingStrategy [Any ] = CommitOffsets (commit )
260- assert self .healthcheck_file
261- next_step = Healthcheck (self .healthcheck_file , next_step )
262-
263- strategy = LaunchpadRunTaskWithMultiprocessing (
264- process_kafka_message_with_service ,
265- next_step = next_step ,
266- max_batch_size = 1 , # Process immediately, subject to be re-tuned
267- max_batch_time = 1 , # Process after 1 second max, subject to be re-tuned
268- pool = self ._pool ,
269- input_block_size = None ,
270- output_block_size = None ,
271- )
284+ do_commit = CommitOffsets (commit )
285+ do_health_check = Healthcheck (self .healthcheck_file , next_step = do_commit )
286+ do_task = RunTaskWithSubprocess (process_kafka_message_with_service , next_step = do_health_check )
272287
273- return strategy
288+ return do_task
274289
275- def close (self ) -> None :
276- """Clean up the multiprocessing pool and logging queue."""
277- self ._pool .close ()
278- self ._queue_listener .stop ()
290+ def shutdown (self ) -> None :
291+ try :
292+ os .remove (self .healthcheck_path )
293+ except FileNotFoundError :
294+ logger .error (f"Failed to remove healthcheck file: { self .healthcheck_path } " )
295+ else :
296+ logger .info (f"Removed healthcheck file: { self .healthcheck_path } " )
279297
280298
281299@dataclass
0 commit comments