Skip to content

Commit 197e7e9

Browse files
committed
feat(pd_router): implement dual routing strategies for pd router
Signed-off-by: CYJiang <[email protected]>
1 parent b67d47a commit 197e7e9

File tree

2 files changed

+316
-48
lines changed

2 files changed

+316
-48
lines changed

pkg/plugins/gateway/algorithms/pd_disaggregation.go

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,27 @@ const (
5656

5757
defaultMaxRequest float64 = 32
5858
defaultMaxTokenThroughputDiff float64 = 2048
59+
defaultRemotePrefillThreshold = 512
5960
)
6061

6162
var (
6263
prefillRequestTimeout int = utils.LoadEnvInt("AIBRIX_PREFILL_REQUEST_TIMEOUT", defaultPrefillRequestTimeout)
6364
aibrixDecodeMaxRequest float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_REQUEST", defaultMaxRequest)
6465
aibrixDecodeMaxThroughputDiff float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_THROUGHPUT", defaultMaxTokenThroughputDiff)
66+
remotePrefillThreshold = utils.LoadEnvInt("AIBRIX_REMOTE_PREFILL_THRESHOLD", defaultRemotePrefillThreshold)
6567
)
6668

6769
func init() {
6870
Register(RouterPD, NewPDRouter)
6971
}
7072

7173
type pdRouter struct {
72-
cache cache.Cache
73-
tokenizer tokenizer.Tokenizer
74-
prefixCacheIndexer *prefixcacheindexer.PrefixHashTable
75-
prefillRequestTracker *PrefillRequestTracker
76-
httpClient *http.Client
74+
cache cache.Cache
75+
tokenizer tokenizer.Tokenizer
76+
prefixCacheIndexer *prefixcacheindexer.PrefixHashTable
77+
prefillRequestTracker *PrefillRequestTracker
78+
httpClient *http.Client
79+
remotePrefillThreshold int
7780
}
7881

7982
// PrefillRequestTracker manages prefill-specific request counts
@@ -110,11 +113,12 @@ func NewPDRouter() (types.Router, error) {
110113
}
111114

112115
return pdRouter{
113-
cache: c,
114-
tokenizer: tokenizerObj,
115-
prefixCacheIndexer: prefixcacheindexer.NewPrefixHashTable(),
116-
prefillRequestTracker: NewPrefillRequestTracker(),
117-
httpClient: httpClient,
116+
cache: c,
117+
tokenizer: tokenizerObj,
118+
prefixCacheIndexer: prefixcacheindexer.NewPrefixHashTable(),
119+
prefillRequestTracker: NewPrefillRequestTracker(),
120+
httpClient: httpClient,
121+
remotePrefillThreshold: remotePrefillThreshold,
118122
}, nil
119123
}
120124

@@ -126,13 +130,89 @@ func NewPrefillRequestTracker() *PrefillRequestTracker {
126130
}
127131
}
128132

133+
// Route determines which pod should handle the incoming LLM request.
134+
// It supports two routing strategies:
135+
//
136+
// 1. **Short-Prompt Optimization (Decode-Only)**:
137+
// - If the input prompt length (in tokens) ≤ AIBRIX_REMOTE_PREFILL_THRESHOLD (default value = 512 tokens),
138+
// the request is routed directly to a decode pod (vLLM engine only, and only when the threshold is > 0).
139+
// - The decode pod performs *both* prefill (KV cache computation) and decoding locally,
140+
// bypassing the remote prefill step entirely.
141+
// - This reduces latency for short prompts by eliminating inter-pod communication.
142+
//
143+
// 2. **Standard Two-Stage Pipeline (Prefill+Decode)**:
144+
// - For longer prompts, the system uses the traditional split architecture:
145+
// a) A prefill pod computes the KV cache.
146+
// b) The result is sent to a decode pod for autoregressive generation.
147+
//
148+
// Route decision flow:
149+
/*
150+
┌───────────────────────┐
151+
│ Client Request │
152+
│ (e.g., "Hello!") │
153+
└──────────┬────────────┘
154+
155+
156+
┌───────────────────────┐
157+
│ pdRouter.Route() │
158+
│ (Tokenize & Decide) │
159+
└──────────┬────────────┘
160+
161+
┌──────────┴────────────┬───────────────────────────────┐
162+
│ │ │
163+
│ Is token count │ Yes │ No
164+
│ ≤ threshold? ▼ ▼
165+
│ ┌─────────────────────┐ ┌─────────────────────┐
166+
│ │ Decode Pod │ │ Prefill Pod │
167+
│ │ (Local Prefill + │ │ (Remote Prefill, │
168+
│ │ Decode in one pod) │ │ KV Cache Compute) │
169+
│ └──────────┬──────────┘ └──────────┬──────────┘
170+
│ │ │
171+
│ │ ▼
172+
│ │ ┌─────────────────────┐
173+
│ │ │ Decode Pod │
174+
│ │ │ (Generation Only) │
175+
│ │ └──────────┬──────────┘
176+
│ │ │
177+
└────────────────────────┼──────────────────────────────┘
178+
179+
180+
┌───────────────────────┐
181+
│ Response to │
182+
│ Client │
183+
└───────────────────────┘
184+
*/
129185
func (r pdRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
130186
// Validate engine consistency across all prefill pods
131187
llmEngine, err := validateAndGetLLMEngine(readyPodList.All())
132188
if err != nil {
133189
return "", fmt.Errorf("engine validation failed for request %s: %w", ctx.RequestID, err)
134190
}
135191

192+
// NOTE: Short-prompt optimization (bypassing remote prefill) is currently ONLY supported for vLLM.
193+
// Other engines (e.g., SGLang) do not support performing prefill+decode in a decode-only pod,
194+
// so they must always go through the full prefill → decode pipeline.
195+
if r.remotePrefillThreshold > 0 && llmEngine == VLLMEngine {
196+
tokens, err := r.tokenizer.TokenizeInputText(ctx.Message)
197+
if err != nil {
198+
klog.Warningf("Tokenization for short-prompt check failed, falling back to standard routing: %v", err)
199+
} else if len(tokens) <= r.remotePrefillThreshold {
200+
klog.InfoS("Short prompt detected: bypassing remote prefill",
201+
"request_id", ctx.RequestID,
202+
"token_count", len(tokens),
203+
"threshold", r.remotePrefillThreshold)
204+
// short prompt optimization (Decode-Only Path)
205+
_, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All())
206+
if err != nil {
207+
klog.Warning("Failed to select decode pod for direct inference; falling back to prefill-decode flow",
208+
"request_id", ctx.RequestID, "error", err)
209+
} else if decodePod != nil {
210+
ctx.SetTargetPod(decodePod)
211+
return ctx.TargetAddress(), nil
212+
}
213+
}
214+
}
215+
136216
prefillPod, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All())
137217
if err != nil {
138218
return "", fmt.Errorf("failed to filter prefill/decode pods for request %s: %w", ctx.RequestID, err)

0 commit comments

Comments
 (0)