1111from typing import Iterable
1212from typing import List
1313from typing import Optional
14+ from typing import Set
1415from typing import Tuple
1516
1617from parsimonious import Grammar
2324
2425CORE_GRAMMAR = r'''
2526 ws = ~r"(\s+|(\s*/\*.*\*/\s*)+)"
26- qs = ~r"\"([^\"]*)\"|'([^\']*)'|`([^\`]*)`|([ A-Za-z0-9_\-\.]+)"
27- number = ~r"[-+]?(\d*\.)?\d+(e[-+]?\d+)?"i
28- integer = ~r"-?\d+"
27+ qs = ~r"\"([^\"]*)\"|'([^\']*)'|([ A-Za-z0-9_\-\.]+)|`([^\`]+)`" ws*
28+ number = ~r"[-+]?(\d*\.)?\d+(e[-+]?\d+)?"i ws*
29+ integer = ~r"-?\d+" ws*
2930 comma = ws* "," ws*
3031 eq = ws* "=" ws*
3132 open_paren = ws* "(" ws*
3233 close_paren = ws* ")" ws*
3334 open_repeats = ws* ~r"[\(\[\{]" ws*
3435 close_repeats = ws* ~r"[\)\]\}]" ws*
3536 select = ~r"SELECT"i ws+ ~r".+" ws*
37+ table = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
38+ column = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
39+ link_name = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
40+ catalog_name = ~r"(?:([A-Za-z0-9_\-]+)|`([^\`]+)`)(?:\.(?:([A-Za-z0-9_\-]+)|`([^\`]+)`))?" ws*
3641
3742 json = ws* json_object ws*
3843 json_object = ~r"{\s*" json_members? ~r"\s*}"
6570 '<integer>' : '' ,
6671 '<number>' : '' ,
6772 '<json>' : '' ,
73+ '<table>' : '' ,
74+ '<column>' : '' ,
75+ '<catalog-name>' : '' ,
76+ '<link-name>' : '' ,
6877}
6978
7079BUILTIN_DEFAULTS = { # type: ignore
@@ -226,9 +235,13 @@ def build_syntax(grammar: str) -> str:
226235 # Split on ';' on a line by itself
227236 cmd , end = grammar .split (';' , 1 )
228237
229- rules = {}
238+ name = ''
239+ rules : Dict [str , Any ] = {}
230240 for line in end .split ('\n ' ):
231241 line = line .strip ()
242+ if line .startswith ('&' ):
243+ rules [name ] += '\n ' + line
244+ continue
232245 if not line :
233246 continue
234247 name , value = line .split ('=' , 1 )
@@ -239,10 +252,16 @@ def build_syntax(grammar: str) -> str:
239252 while re .search (r' [a-z0-9_]+\b' , cmd ):
240253 cmd = re .sub (r' ([a-z0-9_]+)\b' , functools .partial (expand_rules , rules ), cmd )
241254
255+ def add_indent (m : Any ) -> str :
256+ return ' ' + (len (m .group (1 )) * ' ' )
257+
258+ # Indent line-continuations
259+ cmd = re .sub (r'^(\&+)\s*' , add_indent , cmd , flags = re .M )
260+
242261 cmd = textwrap .dedent (cmd ).rstrip () + ';'
243- cmd = re .sub (r' +' , ' ' , cmd )
244- cmd = re .sub (r'^ ' , ' ' , cmd , flags = re . M )
245- cmd = re .sub (r'\s+,\.\.\.' , ',...' , cmd )
262+ cmd = re .sub (r'(\S) +' , r'\1 ' , cmd )
263+ cmd = re .sub (r'<comma> ' , ', ' , cmd )
264+ cmd = re .sub (r'\s+,\s*\ .\.\.' , ',...' , cmd )
246265
247266 return cmd
248267
@@ -399,9 +418,15 @@ def process_grammar(
399418 help_txt = build_help (syntax_txt , full_grammar )
400419 grammar = build_cmd (grammar )
401420
421+ # Remove line-continuations
422+ grammar = re .sub (r'\n\s*&+' , r'' , grammar )
423+
402424 # Make sure grouping characters all have whitespace around them
403425 grammar = re .sub (r' *(\[|\{|\||\}|\]) *' , r' \1 ' , grammar )
404426
427+ grammar = re .sub (r'\(' , r' open_paren ' , grammar )
428+ grammar = re .sub (r'\)' , r' close_paren ' , grammar )
429+
405430 for line in grammar .split ('\n ' ):
406431 if not line .strip ():
407432 continue
@@ -418,7 +443,7 @@ def process_grammar(
418443 sql = re .sub (r'\]\s+\[' , r' | ' , sql )
419444
420445 # Lower-case keywords and make them case-insensitive
421- sql = re .sub (r'(\b|@+)([A-Z0-9 ]+)\b' , lower_and_regex , sql )
446+ sql = re .sub (r'(\b|@+)([A-Z0-9_ ]+)\b' , lower_and_regex , sql )
422447
423448 # Convert literal strings to 'qs'
424449 sql = re .sub (r"'[^']+'" , r'qs' , sql )
@@ -461,12 +486,18 @@ def process_grammar(
461486 sql = re .sub (r'\s+ws$' , r' ws*' , sql )
462487 sql = re .sub (r'\s+ws\s+\(' , r' ws* (' , sql )
463488 sql = re .sub (r'\)\s+ws\s+' , r') ws* ' , sql )
464- sql = re .sub (r'\s+ws\s+' , r' ws+ ' , sql )
489+ sql = re .sub (r'\s+ws\s+' , r' ws* ' , sql )
465490 sql = re .sub (r'\?\s+ws\+' , r'? ws*' , sql )
466491
467492 # Remove extra ws around eq
468493 sql = re .sub (r'ws\+\s*eq\b' , r'eq' , sql )
469494
495+ # Remove optional groupings when mandatory groupings are specified
496+ sql = re .sub (r'open_paren\s+ws\*\s+open_repeats\?' , r'open_paren' , sql )
497+ sql = re .sub (r'close_repeats\?\s+ws\*\s+close_paren' , r'close_paren' , sql )
498+ sql = re .sub (r'open_paren\s+open_repeats\?' , r'open_paren' , sql )
499+ sql = re .sub (r'close_repeats\?\s+close_paren' , r'close_paren' , sql )
500+
470501 out .append (f'{ op } = { sql } ' )
471502
472503 for k , v in list (rules .items ()):
@@ -548,6 +579,7 @@ class SQLHandler(NodeVisitor):
548579
def __init__(self, connection: Connection):
    """Bind the handler to ``connection`` and start with no handled rules."""
    # Rule names that generic_visit has already produced a value for;
    # execute() resets this set around each statement it processes.
    self._handled: Set[str] = set()
    self.connection = connection
551583
552584 @classmethod
553585 def compile (cls , grammar : str = '' ) -> None :
@@ -614,12 +646,16 @@ def execute(self, sql: str) -> result.FusionSQLResult:
614646 )
615647
616648 type (self ).compile ()
649+ self ._handled = set ()
617650 try :
618651 params = self .visit (type (self ).grammar .parse (sql ))
619652 for k , v in params .items ():
620653 params [k ] = self .validate_rule (k , v )
621654
622655 res = self .run (params )
656+
657+ self ._handled = set ()
658+
623659 if res is not None :
624660 res .format_results (self .connection )
625661 return res
@@ -666,16 +702,20 @@ def visit_qs(self, node: Node, visited_children: Iterable[Any]) -> Any:
666702 """Quoted strings."""
667703 if node is None :
668704 return None
669- return node .match .group (1 ) or node .match .group (2 ) or \
670- node .match .group (3 ) or node .match .group (4 )
705+ return flatten (visited_children )[0 ]
706+
def visit_compound(self, node: Node, visited_children: Iterable[Any]) -> Any:
    """
    Compound name.

    Parameters
    ----------
    node : Node
        The parse-tree node being visited (unused here).
    visited_children : Iterable[Any]
        Results already produced by visiting this node's children.

    Returns
    -------
    Any
        The first item of the flattened child results.

    """
    # NOTE: a leftover debugging ``print(visited_children)`` was removed;
    # visitor methods should not write to stdout while parsing.
    return flatten(visited_children)[0]
671711
def visit_number(self, node: Node, visited_children: Iterable[Any]) -> Any:
    """Numeric value: the first flattened child result converted to ``float``."""
    leading = flatten(visited_children)[0]
    return float(leading)
675715
def visit_integer(self, node: Node, visited_children: Iterable[Any]) -> Any:
    """Integer value: the first flattened child result converted to ``int``."""
    leading = flatten(visited_children)[0]
    return int(leading)
679719
680720 def visit_ws (self , node : Node , visited_children : Iterable [Any ]) -> Any :
681721 """Whitespace and comments."""
@@ -804,19 +844,29 @@ def generic_visit(self, node: Node, visited_children: Iterable[Any]) -> Any:
804844 if node .expr_name .endswith ('_cmd' ):
805845 final = merge_dicts (flatten (visited_children )[n_keywords :])
806846 for k , v in type (self ).rule_info .items ():
807- if k .endswith ('_cmd' ) or k .endswith ('_' ):
847+ if k .endswith ('_cmd' ) or k .endswith ('_' ) or k . startswith ( '_' ) :
808848 continue
809- if k not in final :
849+ if k not in final and k not in self . _handled :
810850 final [k ] = BUILTIN_DEFAULTS .get (k , v ['default' ])
811851 return final
812852
813853 # Filter out stray empty strings
814854 out = [x for x in flatten (visited_children )[n_keywords :] if x ]
815855
816- if repeats or len ( out ) > 1 :
817- return { node .expr_name : out }
856+ # Remove underscore prefixes from rule name
857+ key_name = re . sub ( r'^_+' , r'' , node .expr_name )
818858
819- return {node .expr_name : out [0 ] if out else True }
859+ if repeats or len (out ) > 1 :
860+ self ._handled .add (node .expr_name )
861+ # If all outputs are dicts, merge them
862+ if len (out ) > 1 and not repeats :
863+ is_dicts = [x for x in out if isinstance (x , dict )]
864+ if len (is_dicts ) == len (out ):
865+ return {key_name : merge_dicts (out )}
866+ return {key_name : out }
867+
868+ self ._handled .add (node .expr_name )
869+ return {key_name : out [0 ] if out else True }
820870
821871 if hasattr (node , 'match' ):
822872 if not visited_children and not node .match .groups ():
0 commit comments