Skip to content

Commit 5033dc2

Browse files
Add Grammar Cache for XGrammar Backend (#936)
1 parent 8b1c38e commit 5033dc2

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
import functools
13
import torch
24
from typing import List, Tuple
35

@@ -34,6 +36,22 @@ def init_custom(self):
3436
eos_token_ids = []
3537
eos_token_ids.append(self.tokenizer.eos_token_id)
3638
eos_token_ids.extend(self.args.eos_id)
39+
40+
@functools.lru_cache(maxsize=200)
41+
def get_cached_grammar(type: str, grammar: str):
42+
logger.info(f"grammar cache miss for {type}: '{grammar}'")
43+
try:
44+
if type == "grammar":
45+
return self.xgrammar_compiler.compile_grammar(grammar)
46+
elif type == "schema":
47+
return self.xgrammar_compiler.compile_json_schema(grammar)
48+
else:
49+
raise ValueError(f"Unknown xgrammar type: {type}")
50+
except Exception as e:
51+
logger.error(f"Failed to compile {type}: {e}")
52+
raise
53+
54+
self.get_cached_grammar = get_cached_grammar
3755
return
3856

3957
@calculate_time(show=False, min_cost_ms=300)
@@ -149,10 +167,10 @@ def _init_req_xgrammer_matcher_infos(self, run_reqs: List[InferReq]):
149167
sample_params = run_obj.sampling_param
150168
if sample_params.guided_grammar is not None:
151169
if not hasattr(sample_params, "xgrammar_matcher"):
152-
xgrammar_compiled_grammar = self.xgrammar_compiler.compile_grammar(sample_params.guided_grammar)
153-
sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar)
170+
ctx = self.get_cached_grammar("grammar", sample_params.guided_grammar)
171+
sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx)
154172
elif sample_params.guided_json is not None:
155173
if not hasattr(sample_params, "xgrammar_matcher"):
156-
xgrammar_compiled_grammar = self.xgrammar_compiler.compile_json_schema(sample_params.guided_json)
157-
sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar)
174+
ctx = self.get_cached_grammar("schema", sample_params.guided_json)
175+
sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx)
158176
return

0 commit comments

Comments
 (0)