100 changes: 90 additions & 10 deletions pkg/plugins/gateway/algorithms/pd_disaggregation.go
@@ -56,24 +56,27 @@ const (

	defaultMaxRequest             float64 = 32
	defaultMaxTokenThroughputDiff float64 = 2048
	defaultRemotePrefillThreshold         = 512
)

var (
	prefillRequestTimeout         int     = utils.LoadEnvInt("AIBRIX_PREFILL_REQUEST_TIMEOUT", defaultPrefillRequestTimeout)
	aibrixDecodeMaxRequest        float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_REQUEST", defaultMaxRequest)
	aibrixDecodeMaxThroughputDiff float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_THROUGHPUT", defaultMaxTokenThroughputDiff)
	remotePrefillThreshold                = utils.LoadEnvInt("AIBRIX_REMOTE_PREFILL_THRESHOLD", defaultRemotePrefillThreshold)
)

func init() {
	Register(RouterPD, NewPDRouter)
}

type pdRouter struct {
-	cache                 cache.Cache
-	tokenizer             tokenizer.Tokenizer
-	prefixCacheIndexer    *prefixcacheindexer.PrefixHashTable
-	prefillRequestTracker *PrefillRequestTracker
-	httpClient            *http.Client
+	cache                  cache.Cache
+	tokenizer              tokenizer.Tokenizer
+	prefixCacheIndexer     *prefixcacheindexer.PrefixHashTable
+	prefillRequestTracker  *PrefillRequestTracker
+	httpClient             *http.Client
+	remotePrefillThreshold int
}

// PrefillRequestTracker manages prefill-specific request counts
@@ -110,11 +113,12 @@ func NewPDRouter() (types.Router, error) {
}

	return pdRouter{
-		cache:                 c,
-		tokenizer:             tokenizerObj,
-		prefixCacheIndexer:    prefixcacheindexer.NewPrefixHashTable(),
-		prefillRequestTracker: NewPrefillRequestTracker(),
-		httpClient:            httpClient,
+		cache:                  c,
+		tokenizer:              tokenizerObj,
+		prefixCacheIndexer:     prefixcacheindexer.NewPrefixHashTable(),
+		prefillRequestTracker:  NewPrefillRequestTracker(),
+		httpClient:             httpClient,
+		remotePrefillThreshold: remotePrefillThreshold,
	}, nil
}

@@ -126,13 +130,89 @@ func NewPrefillRequestTracker() *PrefillRequestTracker {
}
}

// Route determines which pod should handle the incoming LLM request.
// It supports two routing strategies:
//
// 1. **Short-Prompt Optimization (Decode-Only)**:
// - If the input prompt length (in tokens) is ≤ AIBRIX_REMOTE_PREFILL_THRESHOLD (default = 512 tokens),
// the request is routed directly to a decode pod.
// - The decode pod performs *both* prefill (KV cache computation) and decoding locally,
[Review thread]
Collaborator: How will the gateway know that the decode pod supports both the prefill and decode roles? vLLM supports it, but I'm not sure every engine does.
Author: Oh, that's a good question. I'll check whether sglang supports it, and if not, I'll add a condition limiting this to the vLLM engine.
Author: Checked with the sglang community: sgl-project/sglang#14284
// bypassing the remote prefill step entirely.
// - This reduces latency for short prompts by eliminating inter-pod communication.
//
// 2. **Standard Two-Stage Pipeline (Prefill+Decode)**:
// - For longer prompts, the system uses the traditional split architecture:
// a) A prefill pod computes the KV cache.
// b) The result is sent to a decode pod for autoregressive generation.
//
// Route decision flow:
/*
┌───────────────────────┐
│    Client Request     │
│   (e.g., "Hello!")    │
└──────────┬────────────┘
           │
           ▼
┌───────────────────────┐
│   pdRouter.Route()    │
│  (Tokenize & Decide)  │
└──────────┬────────────┘
           │
┌──────────┴────────────┬───────────────────────────────┐
│                       │                               │
│ Is token count        │ Yes                           │ No
│ ≤ threshold?          ▼                               ▼
│            ┌─────────────────────┐      ┌─────────────────────┐
│            │     Decode Pod      │      │     Prefill Pod     │
│            │  (Local Prefill +   │      │  (Remote Prefill,   │
│            │  Decode in one pod) │      │  KV Cache Compute)  │
│            └──────────┬──────────┘      └──────────┬──────────┘
│                       │                            │
│                       │                            ▼
│                       │                 ┌─────────────────────┐
│                       │                 │     Decode Pod      │
│                       │                 │  (Generation Only)  │
│                       │                 └──────────┬──────────┘
│                       │                            │
└───────────────────────┼────────────────────────────┘
                        │
                        ▼
┌───────────────────────┐
│     Response to       │
│       Client          │
└───────────────────────┘
*/
func (r pdRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
	// Validate engine consistency across all prefill pods
	llmEngine, err := validateAndGetLLMEngine(readyPodList.All())
	if err != nil {
		return "", fmt.Errorf("engine validation failed for request %s: %w", ctx.RequestID, err)
	}

	// NOTE: Short-prompt optimization (bypassing remote prefill) is currently ONLY supported for vLLM.
	// Other engines (e.g., SGLang) do not support performing prefill+decode in a decode-only pod,
	// so they must always go through the full prefill → decode pipeline.

[Review thread]
Author: add the comment "ONLY supported for vLLM" for now.
Collaborator: In vLLM it supports multiple roles, but it is not necessary that the user configures kv_both. We will need a flag to guide the gateway.
Collaborator: Can you also run a benchmark to compare TTFT/TPOT/E2E P90/P99?
Author: > In vLLM it supports multiple roles, but it is not necessary that the user configures kv_both. Will need a flag to guide the gateway.
Sorry, I don't know much about this part. Could you explain it in more detail or give me more relevant info?
	if r.remotePrefillThreshold > 0 && llmEngine == VLLMEngine {
		tokens, err := r.tokenizer.TokenizeInputText(ctx.Message)
		if err != nil {
			klog.Warningf("Tokenization for short-prompt check failed, falling back to standard routing: %v", err)
		} else if len(tokens) <= r.remotePrefillThreshold {
			klog.InfoS("Short prompt detected: bypassing remote prefill",
				"request_id", ctx.RequestID,
				"token_count", len(tokens),
				"threshold", r.remotePrefillThreshold)
			// Short-prompt optimization (decode-only path).
			_, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All())
			if err != nil {
				// klog has no structured Warning variant, so format the fields explicitly.
				klog.Warningf("Failed to select decode pod for direct inference; falling back to prefill-decode flow, request_id: %s, error: %v",
					ctx.RequestID, err)
			} else if decodePod != nil {
				ctx.SetTargetPod(decodePod)
				return ctx.TargetAddress(), nil
			}
[Review thread on lines +205 to +212]
Contributor (severity: high):

The current implementation for the short-prompt optimization path reuses filterPrefillDecodePods, which is designed to find a pair of prefill and decode pods. This creates a bug: if no prefill pods are available, filterPrefillDecodePods will return an error, causing the short-prompt path to fail and fall back to the standard prefill-decode flow. The standard flow will then also fail for the same reason (no prefill pods).

The short-prompt path should be able to select a decode pod independently of prefill pods. A new test case like the one below would fail with the current implementation:

{
    name: "short prompt with only decode pods available: should succeed",
    readyPods: []*v1.Pod{
        // No prefill pod!
        {ObjectMeta: metav1.ObjectMeta{
            Labels: map[string]string{PDRoleSetIdentifier: "test", PDRoleIdentifier: "decode"},
            Name:   "decode-1",
        }, Status: v1.PodStatus{
            PodIP:      "127.0.0.2",
            Conditions: []v1.PodCondition{{Type: v1.PodReady, Status: v1.ConditionTrue}},
        }},
    },
    message:              "hi",
    shortPromptThreshold: 2048,
    expectPrefillCall:    false,
    serverCode:           http.StatusOK,
    llmEngine:            VLLMEngine,
    expectError:          false, // This would fail
    expectTargetAddr:     "127.0.0.2:8000",
},

I recommend refactoring the pod selection logic. A new function, say selectDecodePod, should be created to encapsulate the logic for selecting only a decode pod. This would involve extracting the decode-pod-specific selection logic from filterPrefillDecodePods. The Route function can then be updated to call this new function for the short-prompt path.

Suggested change:
-			_, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All())
-			if err != nil {
-				klog.Warning("Failed to select decode pod for direct inference; falling back to prefill-decode flow",
-					"request_id", ctx.RequestID, "error", err)
-			} else if decodePod != nil {
-				ctx.SetTargetPod(decodePod)
-				return ctx.TargetAddress(), nil
-			}
+			decodePod, err := r.selectDecodePod(ctx, readyPodList.All())
+			if err != nil {
+				klog.Warning("Failed to select decode pod for direct inference; falling back to prefill-decode flow",
+					"request_id", ctx.RequestID, "error", err)
+			} else {
+				ctx.SetTargetPod(decodePod)
+				return ctx.TargetAddress(), nil
+			}

Author (@googs1025, Nov 25, 2025): I think we can ignore this comment, because filterPrefillDecodePods checks both the prefill and decode pods; this makes sense for every request round.
}
}

	prefillPod, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All())
	if err != nil {
		return "", fmt.Errorf("failed to filter prefill/decode pods for request %s: %w", ctx.RequestID, err)