|
| 1 | +import copy |
| 2 | +import functools |
1 | 3 | import torch
|
2 | 4 | from typing import List, Tuple
|
3 | 5 |
|
@@ -34,6 +36,22 @@ def init_custom(self):
|
34 | 36 | eos_token_ids = []
|
35 | 37 | eos_token_ids.append(self.tokenizer.eos_token_id)
|
36 | 38 | eos_token_ids.extend(self.args.eos_id)
|
| 39 | + |
| 40 | + @functools.lru_cache(maxsize=200) |
| 41 | + def get_cached_grammar(type: str, grammar: str): |
| 42 | + logger.info(f"grammar cache miss for {type}: '{grammar}'") |
| 43 | + try: |
| 44 | + if type == "grammar": |
| 45 | + return self.xgrammar_compiler.compile_grammar(grammar) |
| 46 | + elif type == "schema": |
| 47 | + return self.xgrammar_compiler.compile_json_schema(grammar) |
| 48 | + else: |
| 49 | + raise ValueError(f"Unknown xgrammar type: {type}") |
| 50 | + except Exception as e: |
| 51 | + logger.error(f"Failed to compile {type}: {e}") |
| 52 | + raise |
| 53 | + |
| 54 | + self.get_cached_grammar = get_cached_grammar |
37 | 55 | return
|
38 | 56 |
|
39 | 57 | @calculate_time(show=False, min_cost_ms=300)
|
@@ -149,10 +167,10 @@ def _init_req_xgrammer_matcher_infos(self, run_reqs: List[InferReq]):
|
149 | 167 | sample_params = run_obj.sampling_param
|
150 | 168 | if sample_params.guided_grammar is not None:
|
151 | 169 | if not hasattr(sample_params, "xgrammar_matcher"):
|
152 |
| - xgrammar_compiled_grammar = self.xgrammar_compiler.compile_grammar(sample_params.guided_grammar) |
153 |
| - sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar) |
| 170 | + ctx = self.get_cached_grammar("grammar", sample_params.guided_grammar) |
| 171 | + sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx) |
154 | 172 | elif sample_params.guided_json is not None:
|
155 | 173 | if not hasattr(sample_params, "xgrammar_matcher"):
|
156 |
| - xgrammar_compiled_grammar = self.xgrammar_compiler.compile_json_schema(sample_params.guided_json) |
157 |
| - sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar) |
| 174 | + ctx = self.get_cached_grammar("schema", sample_params.guided_json) |
| 175 | + sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx) |
158 | 176 | return
|
0 commit comments