Skip to content

Commit db2b7ce

Browse files
authored
Implement PrepareDataPlugin for prefix cache match plugin (#1942)
* Add PrepareRequestData method for the prefix cache plugin * Add prefix cache match scorer * Ensure prefix cache plugin implements all prepare data plugin methods * Enable prepare data plugins behind a feature flag * Add static check to ensure plugin implements PrepareDataPlugin interface * Rename prefix cache match * Add feature gate for prepare data plugin * Update prefix cache match to have total and match length * Add more tests, register prefix cache scorer and address other review comments * Move prefix cache scorer out of this PR * Fix rebase errors * Update directory structure * Update TODOs * Validate type as well in dag validation * Fix rebase errors, add TODO
1 parent 2fba74b commit db2b7ce

File tree

11 files changed

+198
-17
lines changed

11 files changed

+198
-17
lines changed

cmd/epp/runner/runner.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ func (r *Runner) Run(ctx context.Context) error {
245245
}
246246

247247
// --- Setup Datastore ---
248-
epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.FeatureGate])
248+
epf, err := r.setupMetricsCollection(setupLog, r.featureGates[datalayer.ExperimentalDatalayerFeatureGate])
249249
if err != nil {
250250
return err
251251
}
@@ -387,7 +387,7 @@ func (r *Runner) Run(ctx context.Context) error {
387387
MetricsStalenessThreshold: *metricsStalenessThreshold,
388388
Director: director,
389389
SaturationDetector: saturationDetector,
390-
UseExperimentalDatalayerV2: r.featureGates[datalayer.FeatureGate], // pluggable data layer feature flag
390+
UseExperimentalDatalayerV2: r.featureGates[datalayer.ExperimentalDatalayerFeatureGate], // pluggable data layer feature flag
391391
}
392392
if err := serverRunner.SetupWithManager(ctx, mgr); err != nil {
393393
setupLog.Error(err, "Failed to setup EPP controllers")
@@ -479,8 +479,9 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End
479479
}
480480
}
481481

482-
loader.RegisterFeatureGate(datalayer.FeatureGate)
482+
loader.RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate)
483483
loader.RegisterFeatureGate(flowcontrol.FeatureGate)
484+
loader.RegisterFeatureGate(datalayer.PrepareDataPluginsFeatureGate)
484485

485486
r.registerInTreePlugins()
486487

@@ -520,10 +521,16 @@ func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *conf
520521

521522
// Add requestControl plugins
522523
r.requestControlConfig.AddPlugins(handle.GetAllPlugins()...)
524+
523525
// Sort prepare data plugins in DAG order (topological sort). Also check prepare data plugins for cycles.
524526
if r.requestControlConfig.PrepareDataPluginGraph() != nil {
525527
return nil, errors.New("failed to load the configuration - prepare data plugins have cyclic dependencies")
526528
}
529+
// TODO(#1970): Remove feature gate check once prepare data plugins are stable.
530+
if !r.featureGates[datalayer.PrepareDataPluginsFeatureGate] {
531+
// If the feature gate is disabled, clear any prepare data plugins so they are not used.
532+
r.requestControlConfig.WithPrepareDataPlugins()
533+
}
527534

528535
// Handler deprecated configuration options
529536
r.deprecatedConfigurationHelper(cfg, logger)
@@ -545,7 +552,7 @@ func (r *Runner) deprecatedConfigurationHelper(cfg *config.Config, logger logr.L
545552

546553
if _, ok := os.LookupEnv(enableExperimentalDatalayerV2); ok {
547554
logger.Info("Enabling the experimental Data Layer V2 using environment variables is deprecated and will be removed in next version")
548-
r.featureGates[datalayer.FeatureGate] = env.GetEnvBool(enableExperimentalDatalayerV2, false, logger)
555+
r.featureGates[datalayer.ExperimentalDatalayerFeatureGate] = env.GetEnvBool(enableExperimentalDatalayerV2, false, logger)
549556
}
550557
if _, ok := os.LookupEnv(enableExperimentalFlowControlLayer); ok {
551558
logger.Info("Enabling the experimental Flow Control layer using environment variables is deprecated and will be removed in next version")

pkg/epp/config/loader/configloader.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func InstantiateAndConfigure(
9797
}
9898

9999
featureGates := loadFeatureConfig(rawConfig.FeatureGates)
100-
dataConfig, err := buildDataLayerConfig(rawConfig.Data, featureGates[datalayer.FeatureGate], handle)
100+
dataConfig, err := buildDataLayerConfig(rawConfig.Data, featureGates[datalayer.ExperimentalDatalayerFeatureGate], handle)
101101
if err != nil {
102102
return nil, fmt.Errorf("data layer config build failed: %w", err)
103103
}

pkg/epp/config/loader/configloader_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func TestLoadRawConfiguration(t *testing.T) {
5858
t.Parallel()
5959

6060
// Register known feature gates for validation.
61-
RegisterFeatureGate(datalayer.FeatureGate)
61+
RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate)
6262

6363
tests := []struct {
6464
name string
@@ -90,7 +90,7 @@ func TestLoadRawConfiguration(t *testing.T) {
9090
},
9191
},
9292
},
93-
FeatureGates: configapi.FeatureGates{datalayer.FeatureGate},
93+
FeatureGates: configapi.FeatureGates{datalayer.ExperimentalDatalayerFeatureGate},
9494
SaturationDetector: &configapi.SaturationDetector{
9595
QueueDepthThreshold: 10,
9696
KVCacheUtilThreshold: 0.8,
@@ -150,7 +150,7 @@ func TestInstantiateAndConfigure(t *testing.T) {
150150
// Not parallel because it modifies global plugin registry.
151151
registerTestPlugins(t)
152152

153-
RegisterFeatureGate(datalayer.FeatureGate)
153+
RegisterFeatureGate(datalayer.ExperimentalDatalayerFeatureGate)
154154

155155
tests := []struct {
156156
name string

pkg/epp/datalayer/factory.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ import (
2626
)
2727

2828
const (
29-
FeatureGate = "dataLayer"
29+
ExperimentalDatalayerFeatureGate = "dataLayer"
30+
PrepareDataPluginsFeatureGate = "prepareDataPlugins"
3031
)
3132

3233
// PoolInfo represents the DataStore information needed for endpoints.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package approximateprefix
18+
19+
import (
20+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
21+
)
22+
23+
// PrefixCacheMatchInfoKey is the key under which prefix cache match
// information is published: the prefix plugin lists it in Produces and uses
// it when attaching a *PrefixCacheMatchInfo to each pod.
const PrefixCacheMatchInfoKey = "PrefixCacheMatchInfoKey"
26+
27+
// PrefixCacheMatchInfo describes how much of a request's prompt prefix was
// matched for a single endpoint. The fields are unexported so that a value is
// read-only once constructed; use NewPrefixCacheMatchInfo to build one.
type PrefixCacheMatchInfo struct {
	// matchLength is the length of the matched prefix.
	// NOTE(review): the prefix plugin fills this from its longest-prefix-match
	// result, so the unit is presumably prompt blocks - confirm.
	matchLength int
	// totalBlocks is the total length the match is measured against; the
	// prefix plugin passes the number of prefix hashes computed for the prompt.
	totalBlocks int
}

// NewPrefixCacheMatchInfo returns a PrefixCacheMatchInfo holding the given
// matched length and total length. The values are stored as provided; no
// validation (for example matchLen <= totalLen) is performed.
func NewPrefixCacheMatchInfo(matchLen int, totalLen int) *PrefixCacheMatchInfo {
	return &PrefixCacheMatchInfo{
		matchLength: matchLen,
		totalBlocks: totalLen,
	}
}

// MatchLength returns the matched prefix length. A nil receiver reports 0,
// i.e. "no match", instead of panicking.
func (p *PrefixCacheMatchInfo) MatchLength() int {
	if p == nil {
		return 0
	}
	return p.matchLength
}

// TotalLength returns the total length the match was measured against. A nil
// receiver reports 0 instead of panicking.
func (p *PrefixCacheMatchInfo) TotalLength() int {
	if p == nil {
		return 0
	}
	return p.totalBlocks
}
46+
47+
func (p *PrefixCacheMatchInfo) Clone() datalayer.Cloneable {
48+
return &PrefixCacheMatchInfo{
49+
matchLength: p.matchLength,
50+
totalBlocks: p.totalBlocks,
51+
}
52+
}

pkg/epp/requestcontrol/dag.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import (
2323

2424
// buildDAG builds a dependency graph among data preparation plugins based on their
2525
// produced and consumed data keys.
26-
func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
26+
func buildDAG(plugins []PrepareDataPlugin) (map[string][]string, error) {
2727
dag := make(map[string][]string)
2828
for _, plugin := range plugins {
2929
dag[plugin.TypedName().String()] = []string{}
@@ -36,11 +36,14 @@ func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
3636
}
3737
// Check whether plugin[i] produces something consumed by plugin[j]. In that case, j depends on i.
3838
if plugins[i].Produces() != nil && plugins[j].Consumes() != nil {
39-
// For all the keys produced by plugin i, check if plugin j consumes any of them.
40-
// If yes, then j depends on i.
41-
for producedKey := range plugins[i].Produces() {
39+
for producedKey, producedData := range plugins[i].Produces() {
4240
// If plugin j consumes the produced key, then j depends on i. We can break after the first match.
43-
if _, ok := plugins[j].Consumes()[producedKey]; ok {
41+
if consumedData, ok := plugins[j].Consumes()[producedKey]; ok {
42+
// Check types are same. Reflection is avoided here for simplicity.
43+
// TODO(#1985): Document this detail in IGW docs.
44+
if producedData != consumedData {
45+
return nil, errors.New("data type mismatch between produced and consumed data for key: " + producedKey)
46+
}
4447
iPluginName := plugins[i].TypedName().String()
4548
jPluginName := plugins[j].TypedName().String()
4649
dag[jPluginName] = append(dag[jPluginName], iPluginName)
@@ -50,7 +53,7 @@ func buildDAG(plugins []PrepareDataPlugin) map[string][]string {
5053
}
5154
}
5255
}
53-
return dag
56+
return dag, nil
5457
}
5558

5659
// sortPlugins builds the dependency graph and returns the plugins ordered in topological order.

pkg/epp/requestcontrol/dag_test.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ func TestPrepareDataGraph(t *testing.T) {
6161
pluginX := &mockPrepareRequestDataP{name: "X", produces: map[string]any{"keyX": nil}, consumes: map[string]any{"keyY": nil}}
6262
pluginY := &mockPrepareRequestDataP{name: "Y", produces: map[string]any{"keyY": nil}, consumes: map[string]any{"keyX": nil}}
6363

64+
// Data type mismatch plugin.
65+
pluginZ1 := &mockPrepareRequestDataP{name: "Z1", produces: map[string]any{"keyZ": int(0)}}
66+
pluginZ2 := &mockPrepareRequestDataP{name: "Z2", consumes: map[string]any{"keyZ": string("")}}
67+
6468
testCases := []struct {
6569
name string
6670
plugins []PrepareDataPlugin
@@ -109,11 +113,23 @@ func TestPrepareDataGraph(t *testing.T) {
109113
expectedDAG: nil,
110114
expectError: true,
111115
},
116+
{
117+
name: "Data type mismatch between produced and consumed data",
118+
plugins: []PrepareDataPlugin{pluginZ1, pluginZ2},
119+
expectedDAG: nil,
120+
expectError: true,
121+
},
112122
}
113123

114124
for _, tc := range testCases {
115125
t.Run(tc.name, func(t *testing.T) {
116-
dag := buildDAG(tc.plugins)
126+
dag, err := buildDAG(tc.plugins)
127+
if err != nil {
128+
if tc.expectError {
129+
assert.Error(t, err)
130+
return
131+
}
132+
}
117133
orderedPlugins, err := sortPlugins(dag, tc.plugins)
118134

119135
if tc.expectError {

pkg/epp/requestcontrol/director.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
347347

348348
func (d *Director) runPrepareDataPlugins(ctx context.Context,
349349
request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) error {
350+
if len(d.requestControlPlugins.prepareDataPlugins) == 0 {
351+
return nil
352+
}
350353
return prepareDataPluginsWithTimeout(prepareDataTimeout, d.requestControlPlugins.prepareDataPlugins, ctx, request, pods)
351354
}
352355

pkg/epp/requestcontrol/request_control_config.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,14 @@ func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
108108
// PrepareDataPluginGraph creates data dependency graph and sorts the plugins in topological order.
109109
// If a cycle is detected, it returns an error.
110110
func (c *Config) PrepareDataPluginGraph() error {
111-
dag := buildDAG(c.prepareDataPlugins)
111+
// TODO(#1988): Add all producer and consumer plugins to the graph.
112+
if len(c.prepareDataPlugins) == 0 {
113+
return nil
114+
}
115+
dag, err := buildDAG(c.prepareDataPlugins)
116+
if err != nil {
117+
return err
118+
}
112119
plugins, err := sortPlugins(dag, c.prepareDataPlugins)
113120
if err != nil {
114121
return err

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
k8stypes "k8s.io/apimachinery/pkg/types"
2929
"sigs.k8s.io/controller-runtime/pkg/log"
3030

31+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
3132
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3233
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
3334
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
@@ -206,6 +207,30 @@ func (p *Plugin) WithName(name string) *Plugin {
206207
return p
207208
}
208209

210+
func (p *Plugin) Produces() map[string]any {
211+
return map[string]any{approximateprefix.PrefixCacheMatchInfoKey: approximateprefix.PrefixCacheMatchInfo{}}
212+
}
213+
214+
func (p *Plugin) Consumes() map[string]any {
215+
return map[string]any{}
216+
}
217+
218+
// PrepareRequestData hashes prompt, finds longest prefix match and stores it in pod as attribute.
219+
func (p *Plugin) PrepareRequestData(ctx context.Context, request *types.LLMRequest, pods []types.Pod) error {
220+
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config), p.config.MaxPrefixBlocksToMatch)
221+
state := &SchedulingContextState{
222+
PrefixHashes: hashes,
223+
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
224+
}
225+
total := len(state.PrefixHashes)
226+
227+
for _, pod := range pods {
228+
matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)]
229+
pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(matchLen, total))
230+
}
231+
return nil
232+
}
233+
209234
// Score returns the scoring result for the given list of pods based on context.
210235
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
211236
// pre score step, hashing prompt and find longest prefix match.

0 commit comments

Comments
 (0)