-
Notifications
You must be signed in to change notification settings - Fork 494
feat(pd_router): implement prompt-length-aware routing for pd router #1800
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -56,24 +56,27 @@ const ( | |||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| defaultMaxRequest float64 = 32 | ||||||||||||||||||||||||||||||||||
| defaultMaxTokenThroughputDiff float64 = 2048 | ||||||||||||||||||||||||||||||||||
| defaultRemotePrefillThreshold = 512 | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| var ( | ||||||||||||||||||||||||||||||||||
| prefillRequestTimeout int = utils.LoadEnvInt("AIBRIX_PREFILL_REQUEST_TIMEOUT", defaultPrefillRequestTimeout) | ||||||||||||||||||||||||||||||||||
| aibrixDecodeMaxRequest float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_REQUEST", defaultMaxRequest) | ||||||||||||||||||||||||||||||||||
| aibrixDecodeMaxThroughputDiff float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_THROUGHPUT", defaultMaxTokenThroughputDiff) | ||||||||||||||||||||||||||||||||||
| remotePrefillThreshold = utils.LoadEnvInt("AIBRIX_REMOTE_PREFILL_THRESHOLD", defaultRemotePrefillThreshold) | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| func init() { | ||||||||||||||||||||||||||||||||||
| Register(RouterPD, NewPDRouter) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| type pdRouter struct { | ||||||||||||||||||||||||||||||||||
| cache cache.Cache | ||||||||||||||||||||||||||||||||||
| tokenizer tokenizer.Tokenizer | ||||||||||||||||||||||||||||||||||
| prefixCacheIndexer *prefixcacheindexer.PrefixHashTable | ||||||||||||||||||||||||||||||||||
| prefillRequestTracker *PrefillRequestTracker | ||||||||||||||||||||||||||||||||||
| httpClient *http.Client | ||||||||||||||||||||||||||||||||||
| cache cache.Cache | ||||||||||||||||||||||||||||||||||
| tokenizer tokenizer.Tokenizer | ||||||||||||||||||||||||||||||||||
| prefixCacheIndexer *prefixcacheindexer.PrefixHashTable | ||||||||||||||||||||||||||||||||||
| prefillRequestTracker *PrefillRequestTracker | ||||||||||||||||||||||||||||||||||
| httpClient *http.Client | ||||||||||||||||||||||||||||||||||
| remotePrefillThreshold int | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // PrefillRequestTracker manages prefill-specific request counts | ||||||||||||||||||||||||||||||||||
|
|
@@ -110,11 +113,12 @@ func NewPDRouter() (types.Router, error) { | |||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return pdRouter{ | ||||||||||||||||||||||||||||||||||
| cache: c, | ||||||||||||||||||||||||||||||||||
| tokenizer: tokenizerObj, | ||||||||||||||||||||||||||||||||||
| prefixCacheIndexer: prefixcacheindexer.NewPrefixHashTable(), | ||||||||||||||||||||||||||||||||||
| prefillRequestTracker: NewPrefillRequestTracker(), | ||||||||||||||||||||||||||||||||||
| httpClient: httpClient, | ||||||||||||||||||||||||||||||||||
| cache: c, | ||||||||||||||||||||||||||||||||||
| tokenizer: tokenizerObj, | ||||||||||||||||||||||||||||||||||
| prefixCacheIndexer: prefixcacheindexer.NewPrefixHashTable(), | ||||||||||||||||||||||||||||||||||
| prefillRequestTracker: NewPrefillRequestTracker(), | ||||||||||||||||||||||||||||||||||
| httpClient: httpClient, | ||||||||||||||||||||||||||||||||||
| remotePrefillThreshold: remotePrefillThreshold, | ||||||||||||||||||||||||||||||||||
| }, nil | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
@@ -126,13 +130,89 @@ func NewPrefillRequestTracker() *PrefillRequestTracker { | |||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // Route determines which pod should handle the incoming LLM request. | ||||||||||||||||||||||||||||||||||
| // It supports two routing strategies: | ||||||||||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||||||||||
| // 1. **Short-Prompt Optimization (Decode-Only)**: | ||||||||||||||||||||||||||||||||||
| // - If the input prompt length (in tokens) ≤ AIBRIX_SHORT_PROMPT_THRESHOLD (default value = 2048 tokens), | ||||||||||||||||||||||||||||||||||
| // the request is routed directly to a decode pod. | ||||||||||||||||||||||||||||||||||
| // - The decode pod performs *both* prefill (KV cache computation) and decoding locally, | ||||||||||||||||||||||||||||||||||
| // bypassing the remote prefill step entirely. | ||||||||||||||||||||||||||||||||||
| // - This reduces latency for short prompts by eliminating inter-pod communication. | ||||||||||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||||||||||
| // 2. **Standard Two-Stage Pipeline (Prefill+Decode)**: | ||||||||||||||||||||||||||||||||||
| // - For longer prompts, the system uses the traditional split architecture: | ||||||||||||||||||||||||||||||||||
| // a) A prefill pod computes the KV cache. | ||||||||||||||||||||||||||||||||||
| // b) The result is sent to a decode pod for autoregressive generation. | ||||||||||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||||||||||
| // Route decision flow: | ||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||
| ┌───────────────────────┐ | ||||||||||||||||||||||||||||||||||
| │ Client Request │ | ||||||||||||||||||||||||||||||||||
| │ (e.g., "Hello!") │ | ||||||||||||||||||||||||||||||||||
| └──────────┬────────────┘ | ||||||||||||||||||||||||||||||||||
| │ | ||||||||||||||||||||||||||||||||||
| ▼ | ||||||||||||||||||||||||||||||||||
| ┌───────────────────────┐ | ||||||||||||||||||||||||||||||||||
| │ pdRouter.Route() │ | ||||||||||||||||||||||||||||||||||
| │ (Tokenize & Decide) │ | ||||||||||||||||||||||||||||||||||
| └──────────┬────────────┘ | ||||||||||||||||||||||||||||||||||
| │ | ||||||||||||||||||||||||||||||||||
| ┌──────────┴────────────┬───────────────────────────────┐ | ||||||||||||||||||||||||||||||||||
| │ │ │ | ||||||||||||||||||||||||||||||||||
| │ Is token count │ Yes │ No | ||||||||||||||||||||||||||||||||||
| │ ≤ threshold? ▼ ▼ | ||||||||||||||||||||||||||||||||||
| │ ┌─────────────────────┐ ┌─────────────────────┐ | ||||||||||||||||||||||||||||||||||
| │ │ Decode Pod │ │ Prefill Pod │ | ||||||||||||||||||||||||||||||||||
| │ │ (Local Prefill + │ │ (Remote Prefill, │ | ||||||||||||||||||||||||||||||||||
| │ │ Decode in one pod) │ │ KV Cache Compute) │ | ||||||||||||||||||||||||||||||||||
| │ └──────────┬──────────┘ └──────────┬──────────┘ | ||||||||||||||||||||||||||||||||||
| │ │ │ | ||||||||||||||||||||||||||||||||||
| │ │ ▼ | ||||||||||||||||||||||||||||||||||
| │ │ ┌─────────────────────┐ | ||||||||||||||||||||||||||||||||||
| │ │ │ Decode Pod │ | ||||||||||||||||||||||||||||||||||
| │ │ │ (Generation Only) │ | ||||||||||||||||||||||||||||||||||
| │ │ └──────────┬──────────┘ | ||||||||||||||||||||||||||||||||||
| │ │ │ | ||||||||||||||||||||||||||||||||||
| └────────────────────────┼──────────────────────────────┘ | ||||||||||||||||||||||||||||||||||
| │ | ||||||||||||||||||||||||||||||||||
| ▼ | ||||||||||||||||||||||||||||||||||
| ┌───────────────────────┐ | ||||||||||||||||||||||||||||||||||
| │ Response to │ | ||||||||||||||||||||||||||||||||||
| │ Client │ | ||||||||||||||||||||||||||||||||||
| └───────────────────────┘ | ||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||
| func (r pdRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) { | ||||||||||||||||||||||||||||||||||
| // Validate engine consistency across all prefill pods | ||||||||||||||||||||||||||||||||||
| llmEngine, err := validateAndGetLLMEngine(readyPodList.All()) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| return "", fmt.Errorf("engine validation failed for request %s: %w", ctx.RequestID, err) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // NOTE: Short-prompt optimization (bypassing remote prefill) is currently ONLY supported for vLLM. | ||||||||||||||||||||||||||||||||||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add comment "ONLY supported for vLLM" for now
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In vLLM it supports multiple roles, but it is not necessary that user configures kv_both. Will need a flag to guide gateway.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also run a benchmark to compare TTFT/TPOT/E2E P90/P99.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
sorry, I don't know much about this part. Could you explain it in more detail or give me more relevant info? |
||||||||||||||||||||||||||||||||||
| // Other engines (e.g., SGLang) do not support performing prefill+decode in a decode-only pod, | ||||||||||||||||||||||||||||||||||
| // so they must always go through the full prefill → decode pipeline. | ||||||||||||||||||||||||||||||||||
| if r.remotePrefillThreshold > 0 && llmEngine == VLLMEngine { | ||||||||||||||||||||||||||||||||||
| tokens, err := r.tokenizer.TokenizeInputText(ctx.Message) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| klog.Warningf("Tokenization for short-prompt check failed, falling back to standard routing: %v", err) | ||||||||||||||||||||||||||||||||||
| } else if len(tokens) <= r.remotePrefillThreshold { | ||||||||||||||||||||||||||||||||||
| klog.InfoS("Short prompt detected: bypassing remote prefill", | ||||||||||||||||||||||||||||||||||
| "request_id", ctx.RequestID, | ||||||||||||||||||||||||||||||||||
| "token_count", len(tokens), | ||||||||||||||||||||||||||||||||||
| "threshold", r.remotePrefillThreshold) | ||||||||||||||||||||||||||||||||||
| // short prompt optimization (Decode-Only Path) | ||||||||||||||||||||||||||||||||||
| _, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All()) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| klog.Warning("Failed to select decode pod for direct inference; falling back to prefill-decode flow", | ||||||||||||||||||||||||||||||||||
| "request_id", ctx.RequestID, "error", err) | ||||||||||||||||||||||||||||||||||
| } else if decodePod != nil { | ||||||||||||||||||||||||||||||||||
| ctx.SetTargetPod(decodePod) | ||||||||||||||||||||||||||||||||||
| return ctx.TargetAddress(), nil | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+205
to
+212
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation for the short-prompt optimization path reuses The short-prompt path should be able to select a decode pod independently of prefill pods. A new test case like the one below would fail with the current implementation: {
name: "short prompt with only decode pods available: should succeed",
readyPods: []*v1.Pod{
// No prefill pod!
{ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{PDRoleSetIdentifier: "test", PDRoleIdentifier: "decode"},
Name: "decode-1",
}, Status: v1.PodStatus{
PodIP: "127.0.0.2",
Conditions: []v1.PodCondition{{Type: v1.PodReady, Status: v1.ConditionTrue}},
}},
},
message: "hi",
shortPromptThreshold: 2048,
expectPrefillCall: false,
serverCode: http.StatusOK,
llmEngine: VLLMEngine,
expectError: false, // This would fail
expectTargetAddr: "127.0.0.2:8000",
},I recommend refactoring the pod selection logic. A new function, say
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can ignore this comment, because |
||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| prefillPod, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All()) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| return "", fmt.Errorf("failed to filter prefill/decode pods for request %s: %w", ctx.RequestID, err) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How will gateway know that decode pod supports both prefill/decode role? vLLM supports it, but not sure if every engine supports this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, that's a good question. I'll do some check on that, such as whether sglang supports it, and if not, I'll add the condition that it's limited to VLLM engine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check in sglang community sgl-project/sglang#14284