Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ Documentation = "https://icecube.github.io/skyllh"
Issues = "https://github.com/icecube/skyllh/issues"

[project.optional-dependencies]
rust = [
"skyllh-rs>=0.1",
]
dev = [
"pre-commit",
"pytest>=9.0.0",
Expand Down
18 changes: 18 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "skyllh-rs"
version = "0.1.0"
edition = "2021"

[lib]
name = "skyllh_rs"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22", features = ["extension-module", "abi3-py311"] }
numpy = "0.22"
ndarray = "0.16"

[profile.release]
opt-level = 3
lto = true
codegen-units = 1
22 changes: 22 additions & 0 deletions rust/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[build-system]
requires = ["maturin>=1.5,<2.0"]
build-backend = "maturin"

[project]
name = "skyllh-rs"
version = "0.1.0"
requires-python = ">=3.11"
license = {text = "GPL-3+"}
description = "Rust-accelerated kernels for the skyllh analysis framework"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Scientific/Engineering :: Physics",
]

[tool.maturin]
features = ["pyo3/extension-module"]
module-name = "skyllh_rs"

[tool.maturin.env]
PYO3_USE_ABI3_FORWARD_COMPATIBILITY = "1"
11 changes: 11 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use pyo3::prelude::*;

mod llhratio;
mod random_choice;

#[pymodule]
fn skyllh_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(llhratio::log_lambda_and_grads, m)?)?;
m.add_function(wrap_pyfunction!(random_choice::weighted_choice_indices, m)?)?;
Ok(())
}
82 changes: 82 additions & 0 deletions rust/src/llhratio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::prelude::*;

// Must match ZeroSigH0SingleDatasetTCLLHRatio._one_plus_alpha = 1e-3
const ONE_PLUS_ALPHA: f64 = 1e-3;
const ALPHA: f64 = ONE_PLUS_ALPHA - 1.0; // -0.999

/// Rust port of ZeroSigH0SingleDatasetTCLLHRatio.calculate_log_lambda_and_grads().
///
/// Returns (log_lambda, grads, nsgrad_i) where grads is the full gradient
/// vector and nsgrad_i is returned so the caller can populate _cache_nsgrad_i.
#[pyfunction]
pub fn log_lambda_and_grads<'py>(
py: Python<'py>,
n: usize,
ns: f64,
ns_pidx: usize,
p_mask: PyReadonlyArray1<bool>,
xi: PyReadonlyArray1<f64>,
dxi_dp: PyReadonlyArray2<f64>,
) -> PyResult<(f64, Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
let xi_arr = xi.as_array();
let dxi_dp_arr = dxi_dp.as_array();
let p_mask_arr = p_mask.as_array();

let n_selected = xi_arr.len();
let n_pure_bkg = n - n_selected;
let n_params = p_mask_arr.len();
let n_non_ns = dxi_dp_arr.ncols();

let mut ll_sum = 0.0f64;
let mut ng_sum = 0.0f64;
let mut nsgrad_i = vec![0.0f64; n_selected];
let mut p_grads = vec![0.0f64; n_non_ns];

// Single pass: compute log_lambda, nsgrad_i, and p_grads together so each
// event's data is touched only once and no intermediate Vec is allocated.
for i in 0..n_selected {
let xv = xi_arr[i];
let ai = ns * xv;
let (ll_i, ng_i, factor) = if ai > ALPHA {
let oo1pai = 1.0 / (1.0 + ai);
(ai.ln_1p(), xv * oo1pai, ns * oo1pai)
} else {
// Taylor expansion to avoid catastrophic cancellation near -1.
let tilde = (ai - ALPHA) / ONE_PLUS_ALPHA;
let ng = (1.0 - tilde) * xv / ONE_PLUS_ALPHA;
let f = ns * (1.0 - tilde) / ONE_PLUS_ALPHA;
(ALPHA.ln_1p() + tilde - 0.5 * tilde * tilde, ng, f)
};

ll_sum += ll_i;
ng_sum += ng_i;
nsgrad_i[i] = ng_i;

let row = dxi_dp_arr.row(i);
for (j, &d) in row.iter().enumerate() {
p_grads[j] += factor * d;
}
}

let log_lambda = ll_sum + (n_pure_bkg as f64) * (-ns / n as f64).ln_1p();

// Build grads vector.
let mut grads = vec![0.0f64; n_params];
grads[ns_pidx] = ng_sum - (n_pure_bkg as f64) / (n as f64 - ns);

// Map p_grads into the full grads vector via p_mask.
let mut p_idx = 0usize;
for (i, &in_mask) in p_mask_arr.iter().enumerate() {
if in_mask {
grads[i] = p_grads[p_idx];
p_idx += 1;
}
}

Ok((
log_lambda,
PyArray1::from_vec_bound(py, grads),
PyArray1::from_vec_bound(py, nsgrad_i),
))
}
40 changes: 40 additions & 0 deletions rust/src/random_choice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use numpy::{PyArray1, PyReadonlyArray1};
use pyo3::prelude::*;

/// Rust port of the argsort + searchsorted core of RandomChoice.__call__().
///
/// Given a pre-drawn uniform_values array (from rss.random.random(size))
/// and the pre-built cdf array, returns the selected indices into self._items.
/// The numpy RNG call stays in Python so seed reproducibility is preserved.
#[pyfunction]
pub fn weighted_choice_indices<'py>(
py: Python<'py>,
cdf: PyReadonlyArray1<f64>,
uniform_values: PyReadonlyArray1<f64>,
) -> PyResult<Bound<'py, PyArray1<usize>>> {
let cdf_sl: Vec<f64> = cdf.as_array().iter().copied().collect();
let uv_sl: Vec<f64> = uniform_values.as_array().iter().copied().collect();
let n = uv_sl.len();
let m = cdf_sl.len();

// Sort positions by ascending uniform value (matches np.argsort).
let mut idxs_of_sort: Vec<usize> = (0..n).collect();
idxs_of_sort.sort_unstable_by(|&a, &b| {
uv_sl[a].partial_cmp(&uv_sl[b]).unwrap_or(std::cmp::Ordering::Equal)
});

// Two-pointer scan over sorted (uniform_values, cdf): O(n + m) vs O(n log m)
// binary search. Both sequences are ascending, so cdf_pos only advances.
// Equivalent to np.searchsorted(cdf, u, side='right') for each u.
let mut idxs = vec![0usize; n];
let mut cdf_pos = 0usize;
for &orig_pos in idxs_of_sort.iter() {
let u = uv_sl[orig_pos];
while cdf_pos < m && cdf_sl[cdf_pos] <= u {
cdf_pos += 1;
}
idxs[orig_pos] = cdf_pos;
}

Ok(PyArray1::from_vec_bound(py, idxs))
}
13 changes: 13 additions & 0 deletions skyllh/_rs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os

if os.environ.get('SKYLLH_DISABLE_RUST', '').strip().lower() in ('1', 'true', 'yes'):
_rs = None
RUST_AVAILABLE = False
else:
try:
import skyllh_rs as _rs

RUST_AVAILABLE = True
except ImportError:
_rs = None
RUST_AVAILABLE = False
16 changes: 16 additions & 0 deletions skyllh/core/llhratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

import numpy as np

from skyllh._rs import (
RUST_AVAILABLE,
_rs,
)
from skyllh.core.config import (
HasConfig,
)
Expand Down Expand Up @@ -554,6 +558,18 @@ def calculate_log_lambda_and_grads(self, N, ns, ns_pidx, p_mask, Xi, dXi_dp):
The (N_fitparams,)-shaped numpy ndarray holding the gradient value
of log_lambda for each fit parameter.
"""
if RUST_AVAILABLE:
log_lambda, grads, nsgrad_i = _rs.log_lambda_and_grads(
N,
ns,
ns_pidx,
p_mask,
np.ascontiguousarray(Xi, dtype=np.float64),
np.ascontiguousarray(dXi_dp, dtype=np.float64),
)
self._cache_nsgrad_i = nsgrad_i
return (log_lambda, grads)

tracing = self._cfg['logging']['enable_tracing']

# Get the number of selected events.
Expand Down
8 changes: 8 additions & 0 deletions skyllh/core/random.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np

from skyllh._rs import (
RUST_AVAILABLE,
_rs,
)
from skyllh.core.py import (
classname,
int_cast,
Expand Down Expand Up @@ -200,6 +204,10 @@ def __call__(
"""
uniform_values = rss.random.random(size)

if RUST_AVAILABLE and size >= 512:
idxs = _rs.weighted_choice_indices(self._cdf, uniform_values)
return self._items[idxs]

# The np.searchsorted function is much faster when the values are
# sorted. But we want to keep the randomness of the returned items.
idxs_of_sort = np.argsort(uniform_values)
Expand Down