diff --git a/mktestdocs/__main__.py b/mktestdocs/__main__.py index 449c344..a6dee34 100644 --- a/mktestdocs/__main__.py +++ b/mktestdocs/__main__.py @@ -50,7 +50,7 @@ def exec_python(source): register_executor("python", exec_python) -def get_codeblock_members(*classes): +def get_codeblock_members(*classes, lang="python"): """ Grabs the docstrings of any methods of any classes that are passed in. """ @@ -61,7 +61,7 @@ def get_codeblock_members(*classes): for name, member in inspect.getmembers(cl): if member.__doc__: results.append(member) - return [m for m in results if len(grab_code_blocks(m.__doc__)) > 0] + return [m for m in results if len(grab_code_blocks(m.__doc__, lang=lang)) > 0] def check_codeblock(block, lang="python"): @@ -76,7 +76,7 @@ def check_codeblock(block, lang="python"): """ first_line = block.split("\n")[0] if lang: - if first_line[3:] != lang: + if first_line.lstrip()[3:] != lang: return "" return "\n".join(block.split("\n")[1:]) @@ -104,12 +104,14 @@ def grab_code_blocks(docstring, lang="python"): block += line + "\n" return [textwrap.dedent(c) for c in codeblocks if c != ""] + def format_docstring(docstring): """Formats docstring to be able to successfully go through dedent.""" if docstring[:1] != "\n": return f"\n {docstring}" return docstring + def check_docstring(obj, lang=""): """ Given a function, test the contents of the docstring. diff --git a/tests/test_class.py b/tests/test_class.py index 73eb385..611c80b 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -60,12 +60,27 @@ def hello(self): """ return self.name + def hfdocs_style(self, value): + """ + Returns value + + Example: + + ```python + from dinosaur import Dinosaur + + dino = Dinosaur() + assert dino.a(1) == 1 + ``` + """ + return value + members = get_codeblock_members(Dinosaur) def test_grab_methods(): - assert len(get_codeblock_members(Dinosaur)) == 4 + assert len(get_codeblock_members(Dinosaur)) == 5 @pytest.mark.parametrize("obj", members, ids=lambda d: d.__qualname__) diff --git a/tests/test_mktestdocs.py b/tests/test_mktestdocs.py index e841f53..1a9bd40 100644 --- a/tests/test_mktestdocs.py +++ b/tests/test_mktestdocs.py @@ -2,9 +2,10 @@ from mktestdocs import check_md_file + def test_readme(monkeypatch): test_dir = pathlib.Path(__file__).parent fpath = test_dir.parent / "README.md" monkeypatch.chdir(test_dir) - check_md_file(fpath=fpath) + check_md_file(fpath=fpath, memory=True)