Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions extensions/pyo3/private/pyo3.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,19 @@ def _py_pyo3_library_impl(ctx):
is_windows = extension.basename.endswith(".dll")

# https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds
ext = ctx.actions.declare_file("{}{}".format(
ctx.label.name,
".pyd" if is_windows else ".so",
))
# Determine the on-disk and logical Python module layout.
module_name = ctx.attr.module if ctx.attr.module else ctx.label.name

# Convert a dotted prefix (e.g. "foo.bar") into a path ("foo/bar").
if ctx.attr.module_prefix:
module_prefix_path = ctx.attr.module_prefix.replace(".", "/")
module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so")
stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name)
else:
module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so")
stub_relpath = "{}.pyi".format(module_name)

ext = ctx.actions.declare_file(module_relpath)
ctx.actions.symlink(
output = ext,
target_file = extension,
Expand All @@ -99,10 +108,10 @@ def _py_pyo3_library_impl(ctx):

stub = None
if _stubs_enabled(ctx.attr.stubs, toolchain):
stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name))
stub = ctx.actions.declare_file(stub_relpath)

args = ctx.actions.args()
args.add(ctx.label.name, format = "--module_name=%s")
args.add(module_name, format = "--module_name=%s")
args.add(ext, format = "--module_path=%s")
args.add(stub, format = "--output=%s")
ctx.actions.run(
Expand Down Expand Up @@ -180,6 +189,12 @@ py_pyo3_library = rule(
"imports": attr.string_list(
doc = "List of import directories to be added to the `PYTHONPATH`.",
),
"module": attr.string(
doc = "The Python module name implemented by this extension.",
),
"module_prefix": attr.string(
doc = "A dotted Python package prefix for the module.",
),
"stubs": attr.int(
doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.",
default = -1,
Expand Down Expand Up @@ -218,6 +233,8 @@ def pyo3_extension(
stubs = None,
version = None,
compilation_mode = "opt",
module = None,
module_prefix = None,
**kwargs):
"""Define a PyO3 python extension module.

Expand Down Expand Up @@ -259,6 +276,8 @@ def pyo3_extension(
For more details see [rust_shared_library][rsl].
compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode)
value to build the extension for. If set to `"current"`, the current configuration will be used.
module (str, optional): The Python module name implemented by this extension.
module_prefix (str, optional): A dotted Python package prefix for the module.
**kwargs (dict): Additional keyword arguments.
"""
tags = kwargs.pop("tags", [])
Expand Down Expand Up @@ -318,6 +337,8 @@ def pyo3_extension(
compilation_mode = compilation_mode,
stubs = stubs_int,
imports = imports,
module = module,
module_prefix = module_prefix,
tags = tags,
visibility = visibility,
**kwargs
Expand Down
17 changes: 17 additions & 0 deletions extensions/pyo3/test/module_prefix/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
load("@rules_python//python:defs.bzl", "py_test")
load("//:defs.bzl", "pyo3_extension")

pyo3_extension(
name = "module_prefix",
srcs = ["bar.rs"],
edition = "2021",
imports = ["."],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain how this interacts with imports? Is imports = ["."] required? What happens if a subdirectory is added? Does it make the module not imported?

module = "bar",
module_prefix = "foo",
)

py_test(
name = "module_prefix_import_test",
srcs = ["module_prefix_import_test.py"],
deps = [":module_prefix"],
)
12 changes: 12 additions & 0 deletions extensions/pyo3/test/module_prefix/bar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use pyo3::prelude::*;

#[pyfunction]
fn thing() -> PyResult<&'static str> {
Ok("hello from rust")
}

#[pymodule]
fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(thing, m)?)?;
Ok(())
}
19 changes: 19 additions & 0 deletions extensions/pyo3/test/module_prefix/module_prefix_import_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Tests that a pyo3 extension can be imported via a module prefix."""

import unittest

import foo.bar # type: ignore


class ModulePrefixImportTest(unittest.TestCase):
"""Test Class."""

def test_import_and_call(self) -> None:
"""Test that a pyo3 extension can be imported via a module prefix."""

result = foo.bar.thing()
self.assertEqual("hello from rust", result)


if __name__ == "__main__":
unittest.main()