Commit d3149e8

feat: Optimize load_policy_line to avoid quadratic individual-character loop (#355)

1 parent b9dba51 commit d3149e8

4 files changed: +122 -16 lines changed
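Why this is faster: the old loop grew each token one character at a time with `tokens[-1] += c`. Because that string is also referenced from the list, CPython cannot resize it in place, so every `+=` copies the whole token built so far, and an n-character token costs O(n^2) overall. The new code finds the delimiters with a regex and takes one slice per token, which is linear. A standalone sketch of the two strategies (the names per_char and sliced are illustrative, not from this commit; per_char omits bracket handling for brevity) makes the scaling visible:

import re
import timeit

_TOKEN_RE = re.compile(r"[,\[\]\(\)]")


def per_char(line):
    # old strategy, simplified: build each token one character at a
    # time; each += copies the entire token accumulated so far
    tokens = [""]
    for c in line:
        if c == ",":
            tokens.append("")
        else:
            tokens[-1] += c
    return [t.strip() for t in tokens]


def sliced(line):
    # new strategy: locate delimiters with a regex, then take exactly
    # one slice per token
    tokens, start = [], 0
    for m in _TOKEN_RE.finditer(line):
        if m.group() == ",":
            tokens.append(line[start:m.start()].strip())
            start = m.end()
    tokens.append(line[start:].strip())
    return tokens


for n in (5_000, 10_000, 20_000):
    line = "x" * n  # a single long token, the worst case for per_char
    t_old = timeit.timeit(lambda: per_char(line), number=5)
    t_new = timeit.timeit(lambda: sliced(line), number=5)
    print(f"n={n}: per_char {t_old:.3f}s, sliced {t_new:.4f}s")

Doubling n roughly doubles the sliced time but roughly quadruples the per_char time.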

.github/workflows/build.yml

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ jobs:
         tests/benchmarks/benchmark_model.py
         tests/benchmarks/benchmark_management_api.py
         tests/benchmarks/benchmark_role_manager.py
+        tests/benchmarks/benchmark_adapter.py
 
     - name: Upload coverage data to coveralls.io
       run: coveralls --service=github

casbin/persist/adapter.py

Lines changed: 38 additions & 16 deletions

@@ -12,34 +12,56 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 
-def load_policy_line(line, model):
-    """loads a text line as a policy rule to model."""
+_INTERESTING_TOKENS_RE = re.compile(r"[,\[\]\(\)]")
+
+
+def _extract_tokens(line):
+    """Return the list of 'tokens' from the line, or None if this line has none"""
 
     if line == "":
-        return
+        return None
 
     if line[:1] == "#":
-        return
+        return None
 
     stack = []
     tokens = []
-    for c in line:
+
+    # The tokens are separated by commas, but we support nesting so a naive `line.split(",")` is
+    # wrong. E.g. `abc(def, ghi), jkl` is two tokens: `abc(def, ghi)` and `jkl`. We do this by
+    # iterating over the locations of any tokens of interest, and either:
+    #
+    # - [](): adjust the nesting depth
+    # - ,: slice the line to save the token, if the , is at the top-level, outside all []()
+    #
+    # `start_idx` represents the start of the current token, that we haven't seen a `,` for yet.
+    start_idx = 0
+    for match in _INTERESTING_TOKENS_RE.finditer(line):
+        c = match.group()
         if c == "[" or c == "(":
             stack.append(c)
-            tokens[-1] += c
         elif c == "]" or c == ")":
             stack.pop()
-            tokens[-1] += c
-        elif c == "," and len(stack) == 0:
-            tokens.append("")
-        else:
-            if len(tokens) == 0:
-                tokens.append(c)
-            else:
-                tokens[-1] += c
-
-    tokens = [x.strip() for x in tokens]
+        elif not stack:
+            # must be a comma outside of any nesting: we've found the end of a top level token so
+            # save that and start a new one
+            tokens.append(line[start_idx : match.start()].strip())
+            start_idx = match.end()
+
+    # trailing token after the last ,
+    tokens.append(line[start_idx:].strip())
+
+    return tokens
+
+
+def load_policy_line(line, model):
+    """loads a text line as a policy rule to model."""
+
+    tokens = _extract_tokens(line)
+    if tokens is None:
+        return
 
     key = tokens[0]
     sec = key[0]
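To make the new helper concrete, here is a small usage example (hypothetical snippet, not part of the commit; the expected output shown in the comments matches the diff's own comment and the tests below):

from casbin.persist.adapter import _extract_tokens

# plain policy line: split on top-level commas, outer whitespace stripped
print(_extract_tokens("p, alice, data1, read"))
# ['p', 'alice', 'data1', 'read']

# commas inside () or [] do not end a token
print(_extract_tokens("abc(def, ghi), jkl"))
# ['abc(def, ghi)', 'jkl']

# empty lines and comments yield None, so load_policy_line skips them
print(_extract_tokens("# a comment"))
# None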

tests/benchmarks/benchmark_adapter.py

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+from casbin.persist.adapter import _extract_tokens
+
+
+def _benchmark_extract_tokens(benchmark, line):
+    @benchmark
+    def run_benchmark():
+        _extract_tokens(line)
+
+
+def test_benchmark_extract_tokens_short_simple(benchmark):
+    _benchmark_extract_tokens(benchmark, "abc,def,ghi")
+
+
+def test_benchmark_extract_tokens_long_simple(benchmark):
+    # fixed UUIDs for length and to be similar to "real world" usage of UUIDs
+    _benchmark_extract_tokens(
+        benchmark,
+        "00000000-0000-0000-0000-000000000000,00000000-0000-0000-0000-000000000001,00000000-0000-0000-0000-000000000002",
+    )
+
+
+def test_benchmark_extract_tokens_short_nested(benchmark):
+    _benchmark_extract_tokens(benchmark, "abc(def,ghi),jkl(mno,pqr)")
+
+
+def test_benchmark_extract_tokens_long_nested(benchmark):
+    _benchmark_extract_tokens(
+        benchmark,
+        "00000000-0000-0000-0000-000000000000(00000000-0000-0000-0000-000000000001,00000000-0000-0000-0000-000000000002),00000000-0000-0000-0000-000000000003(00000000-0000-0000-0000-000000000004,00000000-0000-0000-0000-000000000005)",
+    )
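These benchmarks use the `benchmark` fixture from the pytest-benchmark plugin: applying the fixture as a decorator runs and times the decorated function. Assuming pytest-benchmark is installed, they can be run in isolation with something like `pytest tests/benchmarks/benchmark_adapter.py`. The short/long and simple/nested variants separate the cost of scanning long tokens from the cost of tracking bracket nesting.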

tests/persist/test_adapter.py

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+from casbin.persist.adapter import _extract_tokens
+from tests import TestCaseBase
+
+
+class TestExtractTokens(TestCaseBase):
+    def test_ignore_lines(self):
+        self.assertIsNone(_extract_tokens(""))  # empty
+        self.assertIsNone(_extract_tokens("# comment"))
+
+    def test_simple_lines(self):
+        # split on top-level commas, strip whitespace from start and end
+        self.assertEqual(_extract_tokens("one"), ["one"])
+        self.assertEqual(_extract_tokens("one,two"), ["one", "two"])
+        self.assertEqual(_extract_tokens(" ignore \t,\t external, spaces "), ["ignore", "external", "spaces"])
+
+        self.assertEqual(_extract_tokens("internal spaces preserved"), ["internal spaces preserved"])
+
+    def test_nested_lines(self):
+        # basic nesting within a single token
+        self.assertEqual(
+            _extract_tokens("outside1()"),
+            ["outside1()"],
+        )
+        self.assertEqual(
+            _extract_tokens("outside1(inside1())"),
+            ["outside1(inside1())"],
+        )
+
+        # split on top-level commas, but not on internal ones
+        self.assertEqual(
+            _extract_tokens("outside1(inside1(), inside2())"),
+            ["outside1(inside1(), inside2())"],
+        )
+        self.assertEqual(
+            _extract_tokens("outside1(inside1(), inside2(inside3(), inside4()))"),
+            ["outside1(inside1(), inside2(inside3(), inside4()))"],
+        )
+        self.assertEqual(
+            _extract_tokens("outside1(inside1(), inside2()), outside2(inside3(), inside4())"),
+            ["outside1(inside1(), inside2())", "outside2(inside3(), inside4())"],
+        )
+
+        # different delimiters
+        self.assertEqual(
+            _extract_tokens(
+                "all_square[inside1[], inside2[]],square_and_parens[inside1(), inside2()],parens_and_square(inside1[], inside2[])"
+            ),
+            [
+                "all_square[inside1[], inside2[]]",
+                "square_and_parens[inside1(), inside2()]",
+                "parens_and_square(inside1[], inside2[])",
+            ],
+        )
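Beyond these unit tests, a throwaway cross-check (hypothetical, not part of the commit) can confirm that the regex-based tokenizer agrees with the old character-by-character loop on well-formed lines:

from casbin.persist.adapter import _extract_tokens


def reference(line):
    # the pre-optimization loop, lightly simplified (tokens starts as
    # [""] so the len(tokens) == 0 branch is unnecessary)
    if line == "" or line[:1] == "#":
        return None
    stack = []
    tokens = [""]
    for c in line:
        if c == "[" or c == "(":
            stack.append(c)
            tokens[-1] += c
        elif c == "]" or c == ")":
            stack.pop()
            tokens[-1] += c
        elif c == "," and len(stack) == 0:
            tokens.append("")
        else:
            tokens[-1] += c
    return [x.strip() for x in tokens]


for line in ["p, alice, data1, read", "abc(def, ghi), jkl", "outside1(inside1(), inside2())"]:
    assert _extract_tokens(line) == reference(line), line
print("ok")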
