Skip to content

Commit 8a1e47b

Browse files
committed
feat(pd_router): implement dual routing strategies for pd router
Signed-off-by: CYJiang <[email protected]>
1 parent 0ff33af commit 8a1e47b

File tree

2 files changed

+302
-38
lines changed

2 files changed

+302
-38
lines changed

pkg/plugins/gateway/algorithms/pd_disaggregation.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ const (
5656

5757
defaultMaxRequest float64 = 32
5858
defaultMaxTokenThroughputDiff float64 = 2048
59+
defaultShortPromptThreshold = 512
5960
)
6061

6162
var (
6263
prefillRequestTimeout int = utils.LoadEnvInt("AIBRIX_PREFILL_REQUEST_TIMEOUT", defaultPrefillRequestTimeout)
6364
aibrixDecodeMaxRequest float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_REQUEST", defaultMaxRequest)
6465
aibrixDecodeMaxThroughputDiff float64 = utils.LoadEnvFloat("AIBRIX_DECODE_MAX_THROUGHPUT", defaultMaxTokenThroughputDiff)
66+
shortPromptThreshold = utils.LoadEnvInt("AIBRIX_SHORT_PROMPT_THRESHOLD", defaultShortPromptThreshold)
6567
)
6668

6769
func init() {
@@ -74,6 +76,7 @@ type pdRouter struct {
7476
prefixCacheIndexer *prefixcacheindexer.PrefixHashTable
7577
prefillRequestTracker *PrefillRequestTracker
7678
httpClient *http.Client
79+
shortPromptThreshold int
7780
}
7881

7982
// PrefillRequestTracker manages prefill-specific request counts
@@ -115,6 +118,7 @@ func NewPDRouter() (types.Router, error) {
115118
prefixCacheIndexer: prefixcacheindexer.NewPrefixHashTable(),
116119
prefillRequestTracker: NewPrefillRequestTracker(),
117120
httpClient: httpClient,
121+
shortPromptThreshold: shortPromptThreshold,
118122
}, nil
119123
}
120124

@@ -126,7 +130,79 @@ func NewPrefillRequestTracker() *PrefillRequestTracker {
126130
}
127131
}
128132

133+
// Route determines which pod should handle the incoming LLM request.
134+
// It supports two routing strategies:
135+
//
136+
// 1. **Short-Prompt Optimization (Decode-Only)**:
137+
// - If the input prompt length (in tokens) ≤ AIBRIX_SHORT_PROMPT_THRESHOLD (default value = 2048 tokens),
138+
// the request is routed directly to a decode pod.
139+
// - The decode pod performs *both* prefill (KV cache computation) and decoding locally,
140+
// bypassing the remote prefill step entirely.
141+
// - This reduces latency for short prompts by eliminating inter-pod communication.
142+
//
143+
// 2. **Standard Two-Stage Pipeline (Prefill+Decode)**:
144+
// - For longer prompts, the system uses the traditional split architecture:
145+
// a) A prefill pod computes the KV cache.
146+
// b) The result is sent to a decode pod for autoregressive generation.
147+
//
148+
// Route decision flow:
149+
/*
150+
┌───────────────────────┐
151+
│ Client Request │
152+
│ (e.g., "Hello!") │
153+
└──────────┬────────────┘
154+
155+
156+
┌───────────────────────┐
157+
│ pdRouter.Route() │
158+
│ (Tokenize & Decide) │
159+
└──────────┬────────────┘
160+
161+
┌──────────┴────────────┬───────────────────────────────┐
162+
│ │ │
163+
│ Is token count │ Yes │ No
164+
│ ≤ threshold? ▼ ▼
165+
│ ┌─────────────────────┐ ┌─────────────────────┐
166+
│ │ Decode Pod │ │ Prefill Pod │
167+
│ │ (Local Prefill + │ │ (Remote Prefill, │
168+
│ │ Decode in one pod) │ │ KV Cache Compute) │
169+
│ └──────────┬──────────┘ └──────────┬──────────┘
170+
│ │ │
171+
│ │ ▼
172+
│ │ ┌─────────────────────┐
173+
│ │ │ Decode Pod │
174+
│ │ │ (Generation Only) │
175+
│ │ └──────────┬──────────┘
176+
│ │ │
177+
└────────────────────────┼──────────────────────────────┘
178+
179+
180+
┌───────────────────────┐
181+
│ Response to │
182+
│ Client │
183+
└───────────────────────┘
184+
*/
129185
func (r pdRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
186+
if r.shortPromptThreshold > 0 {
187+
tokens, err := r.tokenizer.TokenizeInputText(ctx.Message)
188+
if err != nil {
189+
klog.Warningf("Tokenization for short-prompt check failed, falling back to standard routing: %v", err)
190+
} else if len(tokens) <= r.shortPromptThreshold {
191+
klog.InfoS("Short prompt detected: bypassing remote prefill",
192+
"request_id", ctx.RequestID,
193+
"token_count", len(tokens),
194+
"threshold", r.shortPromptThreshold)
195+
// short prompt optimization (Decode-Only Path)
196+
_, decodePod, err := r.filterPrefillDecodePods(ctx, readyPodList.All())
197+
if err != nil {
198+
klog.Warning("Failed to select decode pod for direct inference; falling back to prefill-decode flow",
199+
"request_id", ctx.RequestID, "error", err)
200+
} else if decodePod != nil {
201+
ctx.SetTargetPod(decodePod)
202+
return ctx.TargetAddress(), nil
203+
}
204+
}
205+
}
130206
// Validate engine consistency across all prefill pods
131207
llmEngine, err := validateAndGetLLMEngine(readyPodList.All())
132208
if err != nil {

0 commit comments

Comments
 (0)