diff --git a/symbol_exporter/ast_symbol_extractor.py b/symbol_exporter/ast_symbol_extractor.py index 95d9ffa..e95e8d1 100644 --- a/symbol_exporter/ast_symbol_extractor.py +++ b/symbol_exporter/ast_symbol_extractor.py @@ -6,7 +6,7 @@ # Increment when we need the database to be rebuilt (eg adding a new feature) NOT_A_DEFAULT_ARG = "~~NOT_A_DEFAULT~~" -version = "1" # must be an integer +version = "2" # must be an integer builtin_symbols = set(dir(builtins)) @@ -165,9 +165,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> Any: def visit_Name(self, node: ast.Name) -> Any: def get_symbol_name(name): - return self._symbol_in_surface_area(name) or ".".join( + complete_symbol_name = ".".join( [name] + list(reversed(self.attr_stack)) ) + return self._symbol_in_surface_area(complete_symbol_name) or complete_symbol_name name = self.aliases.get(node.id, node.id) if name in builtin_symbols: diff --git a/symbol_exporter/line_inspection.py b/symbol_exporter/line_inspection.py new file mode 100644 index 0000000..9774e39 --- /dev/null +++ b/symbol_exporter/line_inspection.py @@ -0,0 +1,67 @@ +from symbol_exporter.ast_index import get_symbol_table +from symbol_exporter.ast_symbol_extractor import SymbolType, builtin_symbols + + +def infer_filenames(module_symbols): + out = {} + for i, symbol in enumerate(module_symbols): + z = module_symbols.pop(i) + if any(m.startswith(z) for m in module_symbols): + out[z] = z.replace(".", "/") + "__init__.py" + else: + out[z] = z.replace(".", "/") + ".py" + module_symbols.insert(i, z) + return out + + +def munge_artifacts(artifacts): + out = [] + for artifact in set(artifacts) - {"*"}: + fn = artifact.rsplit("/", 1)[-1] + name, version, _ = fn.rsplit("-", 2) + out.append((name, version)) + return out + + +def line_inspection(symbols, pkg_environment): + bad_lines = [] + module_symbols = [s for s in symbols if symbols[s]["type"] == SymbolType.MODULE] + filename_by_module = infer_filenames(module_symbols) + for symbol, md in symbols.items(): + module_name = symbol + while True: + filename = filename_by_module.get(module_name) + if filename: + break + else: + module_name = module_name.rsplit(".", 1)[0] + volume = md.get("data", {}).get("symbols_in_volume", {}) + for volume_symbol in set(volume) - builtin_symbols: + # dereference shadows + effective_volume_symbol = ( + symbols.get(volume_symbol, {}) + .get("data", {}) + .get("shadows", volume_symbol) + ) + # don't bother with our own symbols + if effective_volume_symbol in symbols: + continue + # get symbol table + symbol_table = get_symbol_table( + effective_volume_symbol.partition(".")[0] + ).get("symbol table", {}) + supplying_pkgs = munge_artifacts( + symbol_table.get(effective_volume_symbol, []) + ) + # for each symbol in the group check if overlap between pkg_environment and table + symbol_in_env = set(pkg_environment) & set(supplying_pkgs) + # for those that aren't in pkg env note the line and symbol + if not symbol_in_env: + bad_lines.append( + ( + filename, + volume[volume_symbol]["line number"], + effective_volume_symbol, + ) + ) + return bad_lines diff --git a/tests/test_ast_symbol_extractor.py b/tests/test_ast_symbol_extractor.py index 42d1a2c..43e81bd 100644 --- a/tests/test_ast_symbol_extractor.py +++ b/tests/test_ast_symbol_extractor.py @@ -579,3 +579,21 @@ def b(self): "mm.A.a": {"data": {"lineno": 3}, "type": "function"}, "mm.A.b": {"data": {"lineno": 5}, "type": "function"}, } + + +def test_symbols_in_volume_names(): + code = """ + import ast + + z = [ast.Try] + """ + z = process_code_str(code) + assert z.undeclared_symbols == set() + assert z.post_process_symbols() == { + "mm": { + "data": {"symbols_in_volume": {"ast.Try": {"line number": [4]}}}, + "type": "module", + }, + "mm.ast": {"data": {"shadows": "ast"}, "type": "import"}, + "mm.z": {"data": {"lineno": 4}, "type": "constant"}, + } diff --git a/tests/test_line_inspection.py b/tests/test_line_inspection.py new file mode 100644 index 0000000..d1f37fd --- /dev/null +++ b/tests/test_line_inspection.py @@ -0,0 +1,14 @@ +from symbol_exporter.ast_symbol_extractor import SymbolType +from symbol_exporter.line_inspection import line_inspection + + +def test_line_inspection(): + symbols = { + "z": { + "data": {"symbols_in_volume": "requests.ConnectTimeout"}, + "type": SymbolType.FUNCTION, + } + } + env = [] + actual = line_inspection(symbols, env) + assert set(actual) & {("requests", "2.25.1")}