Rethinking the import system #1516

Status: Open · wants to merge 6 commits into master

5 changes: 5 additions & 0 deletions .editorconfig
@@ -0,0 +1,5 @@
root = true

[*.py]
indent_style = space
indent_size = 4
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ tags
docs/_build
docs/examples
docs/sg_execution_times.rst
/venv
12 changes: 11 additions & 1 deletion lark/lark.py
@@ -71,6 +71,7 @@ class LarkOptions(Serialize):
edit_terminals: Optional[Callable[[TerminalDef], TerminalDef]]
import_paths: 'List[Union[str, Callable[[Union[None, str, PackageResource], str], Tuple[str, str]]]]'
source_path: Optional[str]
legacy_import: bool

OPTIONS_DOC = r"""
**=== General Options ===**
@@ -107,6 +108,8 @@ class LarkOptions(Serialize):
Prevent the tree builder from automagically removing "punctuation" tokens (Default: ``False``)
tree_class
Lark will produce trees comprised of instances of this class instead of the default ``lark.Tree``.
legacy_import
Lark will use the old import system, where imported rules are not namespaced (Default: ``True``)

**=== Algorithm Options ===**

@@ -183,6 +186,7 @@ class LarkOptions(Serialize):
'import_paths': [],
'source_path': None,
'_plugins': {},
'legacy_import': True,
}

def __init__(self, options_dict: Dict[str, Any]) -> None:
@@ -354,7 +358,13 @@ def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None:


# Parse the grammar file and compose the grammars
self.grammar, used_files = load_grammar(grammar, self.source_path, self.options.import_paths, self.options.keep_all_tokens)
self.grammar, used_files = load_grammar(
grammar,
self.source_path,
self.options.import_paths,
self.options.keep_all_tokens,
legacy_import=self.options.legacy_import
)
else:
assert isinstance(grammar, Grammar)
self.grammar = grammar
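
For context, a minimal sketch of what the new flag changes, assuming the `grammars/ab.lark` fixture from the test suite (which defines `startab` and `expr`):

```python
from lark import Lark

grammar = """
%import .grammars.ab (startab, expr)

start: startab
"""

# Default behaviour (legacy_import=True): imported rules keep their plain
# names, so the resulting tree contains `startab` and `expr`.
legacy_parser = Lark(grammar, parser='lalr', source_path=__file__)

# New behaviour: everything pulled in by %import is namespaced with the
# dotted path, so the tree contains `grammars__ab__startab` instead.
new_parser = Lark(grammar, parser='lalr', source_path=__file__,
                  legacy_import=False)
```
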
128 changes: 94 additions & 34 deletions lark/load_grammar.py
@@ -9,7 +9,7 @@
import pkgutil
from ast import literal_eval
from contextlib import suppress
from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence, Generator
from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence, Generator, cast

from .utils import bfs, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique, small_factors, OrderedSet
from .lexer import Token, TerminalDef, PatternStr, PatternRE, Pattern
@@ -1026,10 +1026,10 @@ def on_error(e):
return errors


def _get_mangle(prefix, aliases, base_mangle=None):
def _get_mangle(prefix, imports, base_mangle=None):
def mangle(s):
if s in aliases:
s = aliases[s]
if s in imports:
s = imports[s]
else:
if s[0] == '_':
s = '_%s__%s' % (prefix, s[1:])
@@ -1087,14 +1087,22 @@ class GrammarBuilder:
global_keep_all_tokens: bool
import_paths: List[Union[str, Callable]]
used_files: Dict[str, str]
legacy_import: bool

_definitions: Dict[str, Definition]
_ignore_names: List[str]

def __init__(self, global_keep_all_tokens: bool=False, import_paths: Optional[List[Union[str, Callable]]]=None, used_files: Optional[Dict[str, str]]=None) -> None:
def __init__(
self,
global_keep_all_tokens: bool=False,
import_paths: Optional[List[Union[str, Callable]]]=None,
used_files: Optional[Dict[str, str]]=None,
legacy_import: bool=False
) -> None:
self.global_keep_all_tokens = global_keep_all_tokens
self.import_paths = import_paths or []
self.used_files = used_files or {}
self.legacy_import = legacy_import

self._definitions: Dict[str, Definition] = {}
self._ignore_names: List[str] = []
@@ -1134,7 +1142,19 @@ def _define(self, name, is_term, exp, params=(), options=None, *, override=False):
if name.startswith('__'):
self._grammar_error(is_term, 'Names starting with double-underscore are reserved (Error at {name})', name)

self._definitions[name] = Definition(is_term, exp, params, self._check_options(is_term, options))
if not override:
self._definitions[name] = Definition(is_term, exp, params, self._check_options(is_term, options))
else:
definition = self._definitions[name]
definition.is_term = is_term
definition.tree = exp
definition.params = params
definition.options = self._check_options(is_term, options)

def _link(self, name, defined_name):
assert name not in self._definitions

self._definitions[name] = self._definitions[defined_name]

def _extend(self, name, is_term, exp, params=(), options=None):
if name not in self._definitions:
@@ -1156,7 +1176,7 @@ def _extend(self, name, is_term, exp, params=(), options=None):
assert isinstance(base, Tree) and base.data == 'expansions'
base.children.insert(0, exp)

def _ignore(self, exp_or_name):
def _ignore(self, exp_or_name, dependency_mangle):
if isinstance(exp_or_name, str):
self._ignore_names.append(exp_or_name)
else:
@@ -1170,14 +1190,14 @@ def _ignore(self, exp_or_name):
item ,= item.children
if isinstance(item, Terminal):
# Keep terminal name, no need to create a new definition
self._ignore_names.append(item.name)
self._ignore_names.append(item.name if self.legacy_import else dependency_mangle(item.name))
return

name = '__IGNORE_%d'% len(self._ignore_names)
self._ignore_names.append(name)
self._definitions[name] = Definition(True, t, options=TOKEN_DEFAULT_PRIORITY)

def _unpack_import(self, stmt, grammar_name):
def _unpack_import(self, stmt, grammar_name, base_mangle: Optional[Callable[[str], str]]):
if len(stmt.children) > 1:
path_node, arg1 = stmt.children
else:
@@ -1187,21 +1207,30 @@ def _unpack_import(self, stmt, grammar_name):
if isinstance(arg1, Tree): # Multi import
dotted_path = tuple(path_node.children)
names = arg1.children
aliases = dict(zip(names, names)) # Can't have aliased multi import, so all aliases will be the same as names
if self.legacy_import:
imports = dict(zip(names, names)) # Can't have aliased multi import, so all aliases will be the same as names
else:
mangle = _get_mangle('__'.join(dotted_path), {}, base_mangle)
imports = dict(zip(names, (mangle(name) for name in names))) # Can't have aliased multi import, so all import names will just be mangled
else: # Single import
dotted_path = tuple(path_node.children[:-1])
if not dotted_path:
name ,= path_node.children
raise GrammarError("Nothing was imported from grammar `%s`" % name)
name = path_node.children[-1] # Get name from dotted path
aliases = {name.value: (arg1 or name).value} # Aliases if exist
if self.legacy_import:
imports = {name.value: (arg1 or name).value} # Aliases if exist
else:
mangle = _get_mangle('__'.join(dotted_path), {}, base_mangle)
imports = {(arg1 if arg1 else name).value: mangle(name.value)} # Key is the local name (alias if given), value is the mangled name


if path_node.data == 'import_lib': # Import from library
base_path = None
else: # Relative import
if grammar_name == '<string>': # Import relative to script file path if grammar is coded in script
try:
base_file = os.path.abspath(sys.modules['__main__'].__file__)
base_file = os.path.abspath(cast(str, sys.modules['__main__'].__file__))
except AttributeError:
base_file = None
else:
@@ -1214,9 +1243,9 @@
else:
base_path = os.path.abspath(os.path.curdir)

return dotted_path, base_path, aliases
return dotted_path, base_path, imports

def _unpack_definition(self, tree, mangle):
def _unpack_definition(self, tree, mangle, dependency_mangle, imports):

if tree.data == 'rule':
name, params, exp, opts = _make_rule_tuple(*tree.children)
@@ -1228,45 +1257,64 @@
exp = tree.children[-1]
is_term = True

if not self.legacy_import and name in imports:
self._grammar_error(is_term, "{Type} '{name}' defined more than once", name)

if mangle is not None:
params = tuple(mangle(p) for p in params)
name = mangle(name)

exp = _mangle_definition_tree(exp, mangle)
exp = _mangle_definition_tree(exp, mangle if self.legacy_import else dependency_mangle)
return name, is_term, exp, params, opts


def load_grammar(self, grammar_text: str, grammar_name: str="<?>", mangle: Optional[Callable[[str], str]]=None) -> None:
tree = _parse_grammar(grammar_text, grammar_name)

imports: Dict[Tuple[str, ...], Tuple[Optional[str], Dict[str, str]]] = {}
local_imports: Dict[str, str] = cast(Dict[str, str], None if self.legacy_import else {})

for stmt in tree.children:
if stmt.data == 'import':
dotted_path, base_path, aliases = self._unpack_import(stmt, grammar_name)
dotted_path, base_path, items_or_aliases = self._unpack_import(stmt, grammar_name, None if self.legacy_import else mangle)
if not self.legacy_import:
local_imports.update(items_or_aliases)
try:
import_base_path, import_aliases = imports[dotted_path]
import_base_path, prev_items_or_aliases = imports[dotted_path]
prev_items_or_aliases.update(items_or_aliases)
assert base_path == import_base_path, 'Inconsistent base_path for %s.' % '.'.join(dotted_path)
import_aliases.update(aliases)
except KeyError:
imports[dotted_path] = base_path, aliases
imports[dotted_path] = base_path, items_or_aliases

for dotted_path, (base_path, items_or_aliases) in imports.items():
if self.legacy_import:
self.do_import(dotted_path, base_path, items_or_aliases, mangle, {})
else:
self.do_import(dotted_path, base_path, local_imports, mangle, items_or_aliases)

for dotted_path, (base_path, aliases) in imports.items():
self.do_import(dotted_path, base_path, aliases, mangle)
dependency_mangle: Callable[[str], str]
if not self.legacy_import:
# if this item was imported, get the imported name (alias or mangled)
# if it's local, mangle it, unless we are in the root grammar
dependency_mangle = lambda s: local_imports[s] if s in local_imports else (mangle(s) if mangle else s)
else:
dependency_mangle = cast(Callable[[str], str], None)

for stmt in tree.children:
if stmt.data in ('term', 'rule'):
self._define(*self._unpack_definition(stmt, mangle))
self._define(*self._unpack_definition(stmt, mangle, dependency_mangle, local_imports))
elif stmt.data == 'override':
r ,= stmt.children
self._define(*self._unpack_definition(r, mangle), override=True)
name, is_term, exp, params, options = self._unpack_definition(r, mangle, dependency_mangle, {})
if not self.legacy_import:
name = dependency_mangle(name)
self._define(name, is_term, exp, params, options, override=True)
elif stmt.data == 'extend':
r ,= stmt.children
self._extend(*self._unpack_definition(r, mangle))
self._extend(*self._unpack_definition(r, mangle if self.legacy_import else dependency_mangle, dependency_mangle, {}))
elif stmt.data == 'ignore':
# if mangle is not None, we shouldn't apply ignore, since we aren't in a toplevel grammar
if mangle is None:
self._ignore(*stmt.children)
self._ignore(stmt.children[0], dependency_mangle)
elif stmt.data == 'declare':
for symbol in stmt.children:
assert isinstance(symbol, Symbol), symbol
@@ -1288,7 +1336,6 @@ def load_grammar(self, grammar_text: str, grammar_name: str="<?>", mangle: Optional[Callable[[str], str]]=None) -> None:
}
resolve_term_references(term_defs)


def _remove_unused(self, used):
def rule_dependencies(symbol):
try:
@@ -1303,9 +1350,16 @@ def rule_dependencies(symbol):
self._definitions = {k: v for k, v in self._definitions.items() if k in _used}


def do_import(self, dotted_path: Tuple[str, ...], base_path: Optional[str], aliases: Dict[str, str], base_mangle: Optional[Callable[[str], str]]=None) -> None:
def do_import(
self,
dotted_path: Tuple[str, ...],
base_path: Optional[str],
imports: Dict[str, str],
base_mangle: Optional[Callable[[str], str]],
imported_items: Optional[Dict[str, str]]
) -> None:
assert dotted_path
mangle = _get_mangle('__'.join(dotted_path), aliases, base_mangle)
mangle = _get_mangle('__'.join(dotted_path), imports if self.legacy_import else {}, base_mangle)
grammar_path = os.path.join(*dotted_path) + EXT
to_try = self.import_paths + ([base_path] if base_path is not None else []) + [stdlib_loader]
for source in to_try:
@@ -1324,14 +1378,20 @@ def do_import(self, dotted_path: Tuple[str, ...], base_path: Optional[str], aliases: Dict[str, str], base_mangle: Optional[Callable[[str], str]]=None) -> None:
raise RuntimeError("Grammar file was changed during importing")
self.used_files[joined_path] = h

gb = GrammarBuilder(self.global_keep_all_tokens, self.import_paths, self.used_files)
gb = GrammarBuilder(self.global_keep_all_tokens, self.import_paths, self.used_files, self.legacy_import)
gb.load_grammar(text, joined_path, mangle)
gb._remove_unused(map(mangle, aliases))
gb._remove_unused(map(mangle, imports) if self.legacy_import else imports.values())
for name in gb._definitions:
if name in self._definitions:
raise GrammarError("Cannot import '%s' from '%s': Symbol already defined." % (name, grammar_path))

self._definitions.update(**gb._definitions)

if not self.legacy_import:
# linking re-imports
for name, mangled in cast(Dict[str, str], imported_items).items():
self._link(base_mangle(name) if base_mangle is not None else name, mangled)

break
else:
# Search failed. Make Python throw a nice error.
@@ -1406,12 +1466,12 @@ def verify_used_files(file_hashes):

def list_grammar_imports(grammar, import_paths=[]):
"Returns a list of paths to the lark grammars imported by the given grammar (recursively)"
builder = GrammarBuilder(False, import_paths)
builder = GrammarBuilder(False, import_paths, legacy_import=False)
builder.load_grammar(grammar, '<string>')
return list(builder.used_files.keys())

def load_grammar(grammar, source, import_paths, global_keep_all_tokens):
builder = GrammarBuilder(global_keep_all_tokens, import_paths)
def load_grammar(grammar, source, import_paths, global_keep_all_tokens, legacy_import):
builder = GrammarBuilder(global_keep_all_tokens, import_paths, legacy_import=legacy_import)
builder.load_grammar(grammar, source)
return builder.build(), builder.used_files

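
For reference, the namespacing scheme applied by `_get_mangle` can be sketched in isolation. This is a simplified model, not the library code itself: it ignores the alias/import table and the `base_mangle` composition used for nested imports.

```python
def mangle(prefix: str, name: str) -> str:
    # Names starting with '_' keep their leading underscore, preserving
    # Lark's convention that underscored rules are inlined/filtered.
    if name.startswith('_'):
        return '_%s__%s' % (prefix, name[1:])
    return '%s__%s' % (prefix, name)

# For `%import .grammars.ab (startab, expr)` the prefix is 'grammars__ab':
assert mangle('grammars__ab', 'startab') == 'grammars__ab__startab'
assert mangle('grammars__ab', '_expr') == '_grammars__ab__expr'
```
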
4 changes: 4 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,4 @@
# Workaround to force unittest to print out all diffs without truncation
# https://stackoverflow.com/a/61345284
import unittest
__import__('sys').modules['unittest.util']._MAX_LENGTH = 999999999
13 changes: 13 additions & 0 deletions tests/configurations.py
@@ -0,0 +1,13 @@
def configurations(cases):
    def decorator(f):
        # Capture the original names once; mutating f.__name__ inside the
        # loop would otherwise accumulate ".case(...)" suffixes across cases.
        name, qualname = f.__name__, f.__qualname__
        def inner(self):
            for case in cases:
                # Rename per case so a failing run reports which configuration broke
                f.__name__ = f"{name}.case({case})"
                f.__qualname__ = f"{qualname}.case({case})"
                f(self, case)
        inner.__name__ = name
        inner.__qualname__ = qualname
        return inner
    return decorator

import_test = configurations(("new", "legacy"))
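
A hypothetical usage sketch: the decorated test body runs once per configuration, receiving the case name as an extra argument. This mirrors how `test_cache.py` applies `@import_test` below.

```python
import unittest
from .configurations import import_test

class TestExample(unittest.TestCase):  # hypothetical test case
    @import_test
    def test_roundtrip(self, test_type: str):
        # Runs twice: once with test_type == "new", once with "legacy"
        use_legacy = (test_type == "legacy")
        self.assertIn(test_type, ("new", "legacy"))
```
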
20 changes: 14 additions & 6 deletions tests/test_cache.py
@@ -2,6 +2,7 @@

import logging
from unittest import TestCase, main, skipIf
from .configurations import import_test

from lark import Lark, Tree, Transformer, UnexpectedInput
from lark.lexer import Lexer, Token
@@ -134,16 +135,23 @@ def test_inline(self):
res2 = InlineTestT().transform(Lark(g, parser="lalr", cache=True, lexer_callbacks={'NUM': append_zero}).parse(text))
assert res0 == res1 == res2 == expected

def test_imports(self):
@import_test
def test_imports(self, test_type: str):
initial = len(self.mock_fs.files)
g = """
%import .grammars.ab (startab, expr)

start: startab
"""
parser = Lark(g, parser='lalr', start='startab', cache=True, source_path=__file__)
assert len(self.mock_fs.files) == 1
parser = Lark(g, parser='lalr', start='startab', cache=True, source_path=__file__)
assert len(self.mock_fs.files) == 1
parser = Lark(g, parser='lalr', start='start', cache=True, source_path=__file__, legacy_import=(test_type == "legacy"))
assert len(self.mock_fs.files) == (initial + 1)
parser = Lark(g, parser='lalr', start='start', cache=True, source_path=__file__, legacy_import=(test_type == "legacy"))
assert len(self.mock_fs.files) == (initial + 1)
res = parser.parse("ab")
self.assertEqual(res, Tree('startab', [Tree('expr', ['a', 'b'])]))
if test_type == "new":
self.assertEqual(res, Tree(Token('RULE', 'start'), [Tree('grammars__ab__startab', [Tree('grammars__ab__expr', [Token('grammars__ab__A', 'a'), Token('grammars__ab__B', 'b')])])]))
else:
self.assertEqual(res, Tree(Token('RULE', 'start'), [Tree(Token('RULE', 'startab'), [Tree(Token('RULE', 'expr'), [Token('grammars__ab__A', 'a'), Token('grammars__ab__B', 'b')])])]))

@skipIf(regex is None, "'regex' lib not installed")
def test_recursive_pattern(self):