@@ -56,12 +56,14 @@ const (
5656
5757 defaultMaxRequest float64 = 32
5858 defaultMaxTokenThroughputDiff float64 = 2048
59+ defaultShortPromptThreshold = 512
5960)
6061
6162var (
6263 prefillRequestTimeout int = utils .LoadEnvInt ("AIBRIX_PREFILL_REQUEST_TIMEOUT" , defaultPrefillRequestTimeout )
6364 aibrixDecodeMaxRequest float64 = utils .LoadEnvFloat ("AIBRIX_DECODE_MAX_REQUEST" , defaultMaxRequest )
6465 aibrixDecodeMaxThroughputDiff float64 = utils .LoadEnvFloat ("AIBRIX_DECODE_MAX_THROUGHPUT" , defaultMaxTokenThroughputDiff )
66+ shortPromptThreshold = utils .LoadEnvInt ("AIBRIX_SHORT_PROMPT_THRESHOLD" , defaultShortPromptThreshold )
6567)
6668
6769func init () {
@@ -74,6 +76,7 @@ type pdRouter struct {
7476 prefixCacheIndexer * prefixcacheindexer.PrefixHashTable
7577 prefillRequestTracker * PrefillRequestTracker
7678 httpClient * http.Client
79+ shortPromptThreshold int
7780}
7881
7982// PrefillRequestTracker manages prefill-specific request counts
@@ -115,6 +118,7 @@ func NewPDRouter() (types.Router, error) {
115118 prefixCacheIndexer : prefixcacheindexer .NewPrefixHashTable (),
116119 prefillRequestTracker : NewPrefillRequestTracker (),
117120 httpClient : httpClient ,
121+ shortPromptThreshold : shortPromptThreshold ,
118122 }, nil
119123}
120124
@@ -126,7 +130,79 @@ func NewPrefillRequestTracker() *PrefillRequestTracker {
126130 }
127131}
128132
133+ // Route determines which pod should handle the incoming LLM request.
134+ // It supports two routing strategies:
135+ //
136+ // 1. **Short-Prompt Optimization (Decode-Only)**:
137+ // - If the input prompt length (in tokens) ≤ AIBRIX_SHORT_PROMPT_THRESHOLD (default value = 512 tokens),
138+ // the request is routed directly to a decode pod. A threshold ≤ 0 disables this path.
139+ // - The decode pod performs *both* prefill (KV cache computation) and decoding locally,
140+ // bypassing the remote prefill step entirely.
141+ // - This reduces latency for short prompts by eliminating inter-pod communication.
142+ //
143+ // 2. **Standard Two-Stage Pipeline (Prefill+Decode)**:
144+ // - For longer prompts, the system uses the traditional split architecture:
145+ // a) A prefill pod computes the KV cache.
146+ // b) The result is sent to a decode pod for autoregressive generation.
147+ //
148+ // Route decision flow:
149+ /*
150+ ┌───────────────────────┐
151+ │ Client Request │
152+ │ (e.g., "Hello!") │
153+ └──────────┬────────────┘
154+ │
155+ ▼
156+ ┌───────────────────────┐
157+ │ pdRouter.Route() │
158+ │ (Tokenize & Decide) │
159+ └──────────┬────────────┘
160+ │
161+ ┌──────────┴────────────┬───────────────────────────────┐
162+ │ │ │
163+ │ Is token count │ Yes │ No
164+ │ ≤ threshold? ▼ ▼
165+ │ ┌─────────────────────┐ ┌─────────────────────┐
166+ │ │ Decode Pod │ │ Prefill Pod │
167+ │ │ (Local Prefill + │ │ (Remote Prefill, │
168+ │ │ Decode in one pod) │ │ KV Cache Compute) │
169+ │ └──────────┬──────────┘ └──────────┬──────────┘
170+ │ │ │
171+ │ │ ▼
172+ │ │ ┌─────────────────────┐
173+ │ │ │ Decode Pod │
174+ │ │ │ (Generation Only) │
175+ │ │ └──────────┬──────────┘
176+ │ │ │
177+ └────────────────────────┼──────────────────────────────┘
178+ │
179+ ▼
180+ ┌───────────────────────┐
181+ │ Response to │
182+ │ Client │
183+ └───────────────────────┘
184+ */
129185func (r pdRouter ) Route (ctx * types.RoutingContext , readyPodList types.PodList ) (string , error ) {
186+ if r .shortPromptThreshold > 0 {
187+ tokens , err := r .tokenizer .TokenizeInputText (ctx .Message )
188+ if err != nil {
189+ klog .Warningf ("Tokenization for short-prompt check failed, falling back to standard routing: %v" , err )
190+ } else if len (tokens ) <= r .shortPromptThreshold {
191+ klog .InfoS ("Short prompt detected: bypassing remote prefill" ,
192+ "request_id" , ctx .RequestID ,
193+ "token_count" , len (tokens ),
194+ "threshold" , r .shortPromptThreshold )
195+ // short prompt optimization (Decode-Only Path)
196+ _ , decodePod , err := r .filterPrefillDecodePods (ctx , readyPodList .All ())
197+ if err != nil {
198+ klog .Warning ("Failed to select decode pod for direct inference; falling back to prefill-decode flow" ,
199+ "request_id" , ctx .RequestID , "error" , err )
200+ } else if decodePod != nil {
201+ ctx .SetTargetPod (decodePod )
202+ return ctx .TargetAddress (), nil
203+ }
204+ }
205+ }
130206 // Validate engine consistency across all prefill pods
131207 llmEngine , err := validateAndGetLLMEngine (readyPodList .All ())
132208 if err != nil {
0 commit comments