@@ -56,24 +56,27 @@ const (
5656
5757 defaultMaxRequest float64 = 32
5858 defaultMaxTokenThroughputDiff float64 = 2048
59+ defaultRemotePrefillThreshold = 512
5960)
6061
6162var (
6263 prefillRequestTimeout int = utils .LoadEnvInt ("AIBRIX_PREFILL_REQUEST_TIMEOUT" , defaultPrefillRequestTimeout )
6364 aibrixDecodeMaxRequest float64 = utils .LoadEnvFloat ("AIBRIX_DECODE_MAX_REQUEST" , defaultMaxRequest )
6465 aibrixDecodeMaxThroughputDiff float64 = utils .LoadEnvFloat ("AIBRIX_DECODE_MAX_THROUGHPUT" , defaultMaxTokenThroughputDiff )
66+ remotePrefillThreshold = utils .LoadEnvInt ("AIBRIX_REMOTE_PREFILL_THRESHOLD" , defaultRemotePrefillThreshold )
6567)
6668
6769func init () {
6870 Register (RouterPD , NewPDRouter )
6971}
7072
7173type pdRouter struct {
72- cache cache.Cache
73- tokenizer tokenizer.Tokenizer
74- prefixCacheIndexer * prefixcacheindexer.PrefixHashTable
75- prefillRequestTracker * PrefillRequestTracker
76- httpClient * http.Client
74+ cache cache.Cache
75+ tokenizer tokenizer.Tokenizer
76+ prefixCacheIndexer * prefixcacheindexer.PrefixHashTable
77+ prefillRequestTracker * PrefillRequestTracker
78+ httpClient * http.Client
79+ remotePrefillThreshold int
7780}
7881
7982// PrefillRequestTracker manages prefill-specific request counts
@@ -110,11 +113,12 @@ func NewPDRouter() (types.Router, error) {
110113 }
111114
112115 return pdRouter {
113- cache : c ,
114- tokenizer : tokenizerObj ,
115- prefixCacheIndexer : prefixcacheindexer .NewPrefixHashTable (),
116- prefillRequestTracker : NewPrefillRequestTracker (),
117- httpClient : httpClient ,
116+ cache : c ,
117+ tokenizer : tokenizerObj ,
118+ prefixCacheIndexer : prefixcacheindexer .NewPrefixHashTable (),
119+ prefillRequestTracker : NewPrefillRequestTracker (),
120+ httpClient : httpClient ,
121+ remotePrefillThreshold : remotePrefillThreshold ,
118122 }, nil
119123}
120124
@@ -126,13 +130,89 @@ func NewPrefillRequestTracker() *PrefillRequestTracker {
126130 }
127131}
128132
133+ // Route determines which pod should handle the incoming LLM request.
134+ // It supports two routing strategies:
135+ //
136+ // 1. **Short-Prompt Optimization (Decode-Only)**:
137+ // - If the input prompt length (in tokens) ≤ AIBRIX_REMOTE_PREFILL_THRESHOLD (default value = 512 tokens),
138+ // the request is routed directly to a decode pod.
139+ // - The decode pod performs *both* prefill (KV cache computation) and decoding locally,
140+ // bypassing the remote prefill step entirely.
141+ // - This reduces latency for short prompts by eliminating inter-pod communication.
142+ //
143+ // 2. **Standard Two-Stage Pipeline (Prefill+Decode)**:
144+ // - For longer prompts, the system uses the traditional split architecture:
145+ // a) A prefill pod computes the KV cache.
146+ // b) The result is sent to a decode pod for autoregressive generation.
147+ //
148+ // Route decision flow:
149+ /*
150+ ┌───────────────────────┐
151+ │ Client Request │
152+ │ (e.g., "Hello!") │
153+ └──────────┬────────────┘
154+ │
155+ ▼
156+ ┌───────────────────────┐
157+ │ pdRouter.Route() │
158+ │ (Tokenize & Decide) │
159+ └──────────┬────────────┘
160+ │
161+ ┌──────────┴────────────┬───────────────────────────────┐
162+ │ │ │
163+ │ Is token count │ Yes │ No
164+ │ ≤ threshold? ▼ ▼
165+ │ ┌─────────────────────┐ ┌─────────────────────┐
166+ │ │ Decode Pod │ │ Prefill Pod │
167+ │ │ (Local Prefill + │ │ (Remote Prefill, │
168+ │ │ Decode in one pod) │ │ KV Cache Compute) │
169+ │ └──────────┬──────────┘ └──────────┬──────────┘
170+ │ │ │
171+ │ │ ▼
172+ │ │ ┌─────────────────────┐
173+ │ │ │ Decode Pod │
174+ │ │ │ (Generation Only) │
175+ │ │ └──────────┬──────────┘
176+ │ │ │
177+ └────────────────────────┼──────────────────────────────┘
178+ │
179+ ▼
180+ ┌───────────────────────┐
181+ │ Response to │
182+ │ Client │
183+ └───────────────────────┘
184+ */
129185func (r pdRouter ) Route (ctx * types.RoutingContext , readyPodList types.PodList ) (string , error ) {
130186 // Validate engine consistency across all prefill pods
131187 llmEngine , err := validateAndGetLLMEngine (readyPodList .All ())
132188 if err != nil {
133189 return "" , fmt .Errorf ("engine validation failed for request %s: %w" , ctx .RequestID , err )
134190 }
135191
192+ // NOTE: Short-prompt optimization (bypassing remote prefill) is currently ONLY supported for vLLM.
193+ // Other engines (e.g., SGLang) do not support performing prefill+decode in a decode-only pod,
194+ // so they must always go through the full prefill → decode pipeline.
195+ if r .remotePrefillThreshold > 0 && llmEngine == VLLMEngine {
196+ tokens , err := r .tokenizer .TokenizeInputText (ctx .Message )
197+ if err != nil {
198+ klog .Warningf ("Tokenization for short-prompt check failed, falling back to standard routing: %v" , err )
199+ } else if len (tokens ) <= r .remotePrefillThreshold {
200+ klog .InfoS ("Short prompt detected: bypassing remote prefill" ,
201+ "request_id" , ctx .RequestID ,
202+ "token_count" , len (tokens ),
203+ "threshold" , r .remotePrefillThreshold )
204+ // short prompt optimization (Decode-Only Path)
205+ _ , decodePod , err := r .filterPrefillDecodePods (ctx , readyPodList .All ())
206+ if err != nil {
207+ klog .Warning ("Failed to select decode pod for direct inference; falling back to prefill-decode flow" ,
208+ "request_id" , ctx .RequestID , "error" , err )
209+ } else if decodePod != nil {
210+ ctx .SetTargetPod (decodePod )
211+ return ctx .TargetAddress (), nil
212+ }
213+ }
214+ }
215+
136216 prefillPod , decodePod , err := r .filterPrefillDecodePods (ctx , readyPodList .All ())
137217 if err != nil {
138218 return "" , fmt .Errorf ("failed to filter prefill/decode pods for request %s: %w" , ctx .RequestID , err )
0 commit comments