From 8c728d865729bab2f3e2e35e651100e582f0addd Mon Sep 17 00:00:00 2001
From: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Date: Wed, 11 Jun 2025 19:37:54 +0200
Subject: [PATCH] WIP: model memoization

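Memoize graph-building and compilation methods on the Model instance
(`logp`, `dlogp`, `d2logp`, `logp_dlogp_function`, `replace_rvs_by_values`,
`_compile_fn`, `_make_initial_point`) and invalidate the cache whenever the
model is mutated (`register_rv`, `set_data`, `add_coord`, `set_dim`, ...).
`locally_cachedmethod` is renamed to `memoize` and paired with a new
`invalidates_memoize` decorator. Functions compiled through the model are
cached unseeded and reseeded on demand via `seed_compiled_function`, and the
`initial_point`/`logp_dlogp_func` plumbing in the step samplers is dropped in
favor of the cached model methods.

A rough sketch of the intended caching behaviour:

    with pm.Model() as m:
        x = pm.Normal("x")

    logp1 = m.logp()
    assert logp1 is m.logp()        # repeated calls reuse the cached graph
    m.add_coord("obs", range(3))    # mutating the model invalidates the cache
    assert m.logp() is not logp1    # a fresh graph is built afterwards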
---
 pymc/backends/base.py             |  12 +-
 pymc/initial_point.py             |  14 ++-
 pymc/model/core.py                | 201 +++++++++++++++++++-----------
 pymc/pytensorf.py                 |  45 ++++---
 pymc/sampling/forward.py          |  51 ++------
 pymc/sampling/mcmc.py             |  61 +--------
 pymc/step_methods/arraystep.py    |  24 ++--
 pymc/step_methods/hmc/base_hmc.py |   8 +-
 pymc/step_methods/metropolis.py   |  22 +---
 pymc/step_methods/slicer.py       |   5 +-
 pymc/util.py                      |  12 +-
 pymc/variational/opvi.py          |   6 +-
 pymc/variational/stein.py         |   4 +-
 tests/model/test_core.py          |  18 +++
 tests/sampling/test_mcmc.py       |  16 ---
 tests/test_util.py                |   4 +-
 16 files changed, 247 insertions(+), 256 deletions(-)

diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 993acc0df4..cdbfc5c32b 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -30,11 +30,9 @@
 )
 
 import numpy as np
-import pytensor
 
 from pymc.backends.report import SamplerReport
 from pymc.model import modelcontext
-from pymc.pytensorf import compile
 from pymc.util import get_var_name
 
 logger = logging.getLogger(__name__)
@@ -171,10 +169,14 @@ def __init__(
 
         if fn is None:
             # borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
-            fn = compile(
-                inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
-                outputs=[pytensor.Out(v, borrow=True) for v in vars],
+            fn = model.compile_fn(
+                inputs=model.value_vars,
+                outs=vars,
                 on_unused_input="ignore",
+                random_seed=False,
+                borrow_inputs=True,
+                borrow_outputs=True,
+                point_fn=False,
             )
             fn.trust_input = True
 
diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index c276a5c496..581d6269f3 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -28,9 +28,8 @@
 from pymc.pytensorf import (
     SeedSequenceSeed,
     compile,
-    find_rng_nodes,
     replace_rng_nodes,
-    reseed_rngs,
+    seed_compiled_function,
     toposort_replace,
 )
 from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name
@@ -167,7 +166,12 @@ def make_initial_point_fn(
     # Replace original rng shared variables so that we don't mess with them
     # when calling the final seeded function
     initial_values = replace_rng_nodes(initial_values)
-    func = compile(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
+    func = compile(
+        inputs=[],
+        outputs=initial_values,
+        mode=pytensor.compile.mode.FAST_COMPILE,
+        random_seed=False,
+    )
 
     varnames = []
     for var in model.free_RVs:
@@ -179,11 +183,9 @@ def make_initial_point_fn(
         varnames.append(name)
 
     def make_seeded_function(func):
-        rngs = find_rng_nodes(func.maker.fgraph.outputs)
-
         @functools.wraps(func)
         def inner(seed, *args, **kwargs):
-            reseed_rngs(rngs, seed)
+            seed_compiled_function(func, seed)
             values = func(*args, **kwargs)
             return dict(zip(varnames, values))
 
diff --git a/pymc/model/core.py b/pymc/model/core.py
index 469001e804..622f44066e 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -58,11 +58,10 @@
     SeedSequenceSeed,
     compile,
     convert_observed_data,
-    gradient,
-    hessian,
     inputvars,
     join_nonshared_inputs,
     rewrite_pregrad,
+    seed_compiled_function,
 )
 from pymc.util import (
     UNSET,
@@ -73,6 +72,8 @@
     get_transformed_name,
     get_value_vars_from_user_vars,
     get_var_name,
+    invalidates_memoize,
+    memoize,
     treedict,
     treelist,
 )
@@ -455,7 +456,8 @@ def __init__(
     ):
         self.name = self._validate_name(name)
         self.check_bounds = check_bounds
-        self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context
+        self.parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context
+        self.isroot = self.parent is None
 
         if coords_mutable is not None:
             warnings.warn(
@@ -514,10 +516,6 @@ def get_context(
             raise TypeError("No model on context stack")
         return model
 
-    @property
-    def parent(self):
-        return self._parent
-
     @property
     def root(self):
         model = self
@@ -525,10 +523,7 @@ def root(self):
             model = model.parent
         return model
 
-    @property
-    def isroot(self):
-        return self.parent is None
-
+    @memoize
     def logp_dlogp_function(
         self,
         grad_vars=None,
@@ -574,7 +569,6 @@ def logp_dlogp_function(
             grad_vars,
             extra_vars_and_values,
             model=self,
-            initial_point=initial_point,
             ravel_inputs=ravel_inputs,
             **kwargs,
         )
@@ -641,6 +635,7 @@ def compile_d2logp(
             **compile_kwargs,
         )
 
+    @memoize
     def logp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
@@ -720,18 +715,21 @@ def logp(
         logp_scalar.name = logp_scalar_name
         return logp_scalar
 
+    @memoize
     def dlogp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
-    ) -> Variable:
+        ravel_outputs: bool = True,
+        return_logp: bool = False,
+    ) -> list[Variable] | Variable | tuple[Variable, list[Variable] | Variable]:
         """Gradient of the models log-probability w.r.t. ``vars``.
 
         Parameters
         ----------
-        vars : list of random variables or potential terms, optional
-            Compute the gradient with respect to those variables. If None, use all
-            free and observed random variables, as well as potential terms in model.
+        vars : list of random variables, optional
+            Compute the gradient with respect to those variables.
+            If None, consider all continuous free variables.
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
 
@@ -740,7 +738,7 @@ def dlogp(
         dlogp graph
         """
         if vars is None:
-            value_vars = None
+            value_vars = self.continuous_value_vars
         else:
             if not isinstance(vars, list | tuple):
                 vars = [vars]
@@ -757,21 +755,27 @@ def dlogp(
 
         cost = self.logp(jacobian=jacobian)
         cost = rewrite_pregrad(cost)
-        return gradient(cost, value_vars)
-
+        gradient = pt.grad(cost, value_vars)
+        if ravel_outputs:
+            gradient = pt.concatenate([g.reshape(-1) for g in gradient], axis=0)
+        if return_logp:
+            return cost, gradient
+        return gradient
+
+    @memoize
     def d2logp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
-        negate_output=True,
+        negate_output: bool | None = None,
     ) -> Variable:
-        """Hessian of the models log-probability w.r.t. ``vars``.
+        """Hessian of the model's log-probability w.r.t. the flattened vector of ``vars``.
 
         Parameters
         ----------
-        vars : list of random variables or potential terms, optional
-            Compute the gradient with respect to those variables. If None, use all
-            free and observed random variables, as well as potential terms in model.
+        vars : list of random variables, optional
+            Compute the hessian with respect to those variables.
+            If None, consider all continuous free variables.
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
 
@@ -780,7 +784,7 @@ def d2logp(
         d²logp graph
         """
         if vars is None:
-            value_vars = None
+            value_vars = self.continuous_value_vars
         else:
             if not isinstance(vars, list | tuple):
                 vars = [vars]
@@ -795,9 +799,26 @@ def d2logp(
                         f"Requested variable {var} not found among the model variables"
                     )
 
-        cost = self.logp(jacobian=jacobian)
-        cost = rewrite_pregrad(cost)
-        return hessian(cost, value_vars, negate_output=negate_output)
+        grad = self.dlogp(
+            vars=[self.values_to_rvs[value] for value in value_vars],
+            jacobian=jacobian,
+            ravel_outputs=True,
+        )
+        # The boolean `jacobian` argument shadows the gradient helper, so call it via pytensor.gradient
+        hess_blocks = pytensor.gradient.jacobian(grad, value_vars, vectorize=True)
+        # Join the per-variable blocks into the Hessian w.r.t. the flattened value vector
+        hess = pt.concatenate([h.reshape((grad.shape[0], -1)) for h in hess_blocks], axis=-1)
+        if negate_output is not None:
+            if negate_output:
+                warnings.warn(
+                    "negate_output is deprecated and will fail in a future release. To comply with the API change, set it to None and negate the result manually",
+                    FutureWarning,
+                )
+                hess = -hess
+            else:
+                warnings.warn(
+                    "negate_output is deprecated and will fail in a future release. To comply with the API change, set it to None. The result is not negated by default.",
+                    FutureWarning,
+                )
+
+        return hess
 
     @property
     def datalogp(self) -> Variable:
@@ -812,6 +833,9 @@ def varlogp(self) -> Variable:
     @property
     def varlogp_nojac(self) -> Variable:
         """PyTensor scalar of log-probability of the unobserved random variables (excluding deterministic) without jacobian term."""
+        warnings.warn(
+            "varlogp_nojac is deprecated, use `model.logp(vars=model.free_RVs, jacobian=False)` instead",
+            FutureWarning,
+        )
         return self.logp(vars=self.free_RVs, jacobian=False)
 
     @property
@@ -824,11 +848,7 @@ def potentiallogp(self) -> Variable:
         """PyTensor scalar of log-probability of the Potential terms."""
         # Convert random variables in Potential expression into their log-likelihood
         # inputs and apply their transforms, if any
-        potentials = self.replace_rvs_by_values(self.potentials)
-        if potentials:
-            return pt.sum([pt.sum(factor) for factor in potentials])
-        else:
-            return pt.constant(0.0)
+        return self.logp(vars=self.potentials)
 
     @property
     def value_vars(self):
@@ -903,6 +923,10 @@ def dim_lengths(self) -> dict[str, TensorVariable]:
         return self._dim_lengths
 
     def shape_from_dims(self, dims):
+        warnings.warn(
+            "model.shape_from_dims is deprecated and will be removed in a future release",
+            FutureWarning,
+        )
         shape = []
         if len(set(dims)) != len(dims):
             raise ValueError("Can not contain the same dimension name twice.")
@@ -917,6 +941,7 @@ def shape_from_dims(self, dims):
             shape.extend(np.shape(self.coords[dim]))
         return tuple(shape)
 
+    @invalidates_memoize
     def add_coord(
         self,
         name: str,
@@ -977,6 +1002,7 @@ def add_coord(
         self._dim_lengths[name] = length
         self._coords[name] = values
 
+    @invalidates_memoize
     def add_coords(
         self,
         coords: dict[str, Sequence | None],
@@ -991,6 +1017,7 @@ def add_coords(
         for name, values in coords.items():
             self.add_coord(name, values, length=lengths.get(name, None))
 
+    @invalidates_memoize
     def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = None):
         """Update a mutable dimension.
 
@@ -1026,6 +1053,10 @@ def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = No
         dim_length.set_value(new_length)
         return
 
+    @memoize
+    def _make_initial_point(self):
+        return make_initial_point_fn(model=self, return_transformed=True)
+
     def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.ndarray]:
         """Compute the initial point of the model.
 
@@ -1039,9 +1070,10 @@ def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.nd
         ip : dict of {str : array_like}
             Maps names of transformed variables to numeric initial values in the transformed space.
         """
-        fn = make_initial_point_fn(model=self, return_transformed=True)
+        fn = self._make_initial_point()
         return Point(fn(random_seed), model=self)
 
+    @invalidates_memoize
     def set_initval(self, rv_var, initval):
         """Set an initial value (strategy) for a random variable."""
         if initval is not None and not isinstance(initval, Variable | str):
@@ -1050,6 +1082,7 @@ def set_initval(self, rv_var, initval):
 
         self.rvs_to_initial_values[rv_var] = initval
 
+    @invalidates_memoize
     def set_data(
         self,
         name: str,
@@ -1185,6 +1218,7 @@ def set_data(
 
         shared_object.set_value(values)
 
+    @invalidates_memoize
     def register_rv(
         self,
         rv_var: RandomVariable,
@@ -1263,6 +1297,7 @@ def register_rv(
 
         return rv_var
 
+    @invalidates_memoize
     def make_obs_var(
         self,
         rv_var: TensorVariable,
@@ -1362,6 +1397,7 @@ def make_obs_var(
 
         return rv_var
 
+    @invalidates_memoize
     def create_value_var(
         self,
         rv_var: TensorVariable,
@@ -1444,11 +1480,13 @@ def create_value_var(
 
         return value_var
 
+    @invalidates_memoize
     def register_data_var(self, data, dims=None):
         """Register a data variable with the model."""
         self.data_vars.append(data)
         self.add_named_variable(data, dims=dims)
 
+    @invalidates_memoize
     def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
         """Add a random graph variable to the named variables of the model.
 
@@ -1566,6 +1604,7 @@ def copy(self):
 
         return clone_model(self)
 
+    @memoize
     def replace_rvs_by_values(
         self,
         graphs: Sequence[TensorVariable],
@@ -1612,6 +1651,40 @@ def compile_fn(
         **kwargs,
     ) -> Function: ...
 
+    @memoize
+    def _compile_fn(
+        self,
+        outs: Variable | Sequence[Variable],
+        *,
+        inputs: Sequence[Variable] | None = None,
+        mode=None,
+        borrow_inputs=False,
+        borrow_outputs=False,
+        **kwargs,
+    ) -> PointFunc | Function:
+        if inputs is None:
+            inputs = inputvars(outs)
+
+        if borrow_inputs:
+            inputs = [pytensor.In(inp, borrow=True) for inp in inputs]
+
+        if borrow_outputs:
+            if isinstance(outs, list | tuple):
+                outs = [pytensor.Out(o, borrow=True) for o in outs]
+            else:
+                outs = pytensor.Out(outs, borrow=True)
+
+        with self:
+            return compile(
+                inputs,
+                outs,
+                allow_input_downcast=True,
+                accept_inplace=True,
+                mode=mode,
+                random_seed=False,
+                **kwargs,
+            )
+
     def compile_fn(
         self,
         outs: Variable | Sequence[Variable],
@@ -1619,6 +1692,8 @@ def compile_fn(
         inputs: Sequence[Variable] | None = None,
         mode=None,
         point_fn: bool = True,
+        borrow_inputs: bool = False,
+        borrow_outputs: bool = False,
         **kwargs,
     ) -> PointFunc | Function:
         """Compiles a PyTensor function.
@@ -1641,21 +1716,19 @@ def compile_fn(
         -------
         Compiled PyTensor function
         """
-        if inputs is None:
-            inputs = inputvars(outs)
-
-        with self:
-            fn = compile(
-                inputs,
-                outs,
-                allow_input_downcast=True,
-                accept_inplace=True,
-                mode=mode,
-                **kwargs,
-            )
-
+        random_seed = kwargs.pop("random_seed", None)
+        fn = self._compile_fn(
+            outs,
+            inputs=inputs,
+            mode=mode,
+            borrow_inputs=borrow_inputs,
+            borrow_outputs=borrow_outputs,
+            **kwargs,
+        )
+        if random_seed is not False:
+            seed_compiled_function(fn, random_seed)
         if point_fn:
-            return PointFunc(fn)
+            fn = PointFunc(fn)
         return fn
 
     def profile(
@@ -1695,29 +1768,13 @@ def profile(
         )
         if point is None:
             point = self.initial_point()
+        point_values = point.values()
 
         for _ in range(n):
-            f(**point)
+            f(*point_values)
 
         return f.profile
 
-    def update_start_vals(self, a: dict[str, np.ndarray], b: dict[str, np.ndarray]):
-        r"""Update point `a` with `b`, without overwriting existing keys.
-
-        Values specified for transformed variables in `a` will be recomputed
-        conditional on the values of `b` and stored in `b`.
-
-        Parameters
-        ----------
-        a : dict
-
-        b : dict
-        """
-        raise FutureWarning(
-            "The `Model.update_start_vals` method was removed."
-            " To change initial values you may set the items of `Model.initial_values` directly."
-        )
-
     def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]:
         """Evaluate shapes of untransformed AND transformed free variables.
 
@@ -1795,7 +1852,7 @@ def check_start_vals(self, start, **kwargs):
                     "You can call `model.debug()` for more details."
                 )
 
-    def point_logps(self, point=None, round_vals=2, **kwargs):
+    def point_logps(self, point=None, round_vals=2):
         """Compute the log probability of `point` for all random variables in the model.
 
         Parameters
@@ -1817,12 +1874,12 @@ def point_logps(self, point=None, round_vals=2, **kwargs):
             point = self.initial_point()
 
         factors = self.basic_RVs + self.potentials
-        factor_logps_fn = [pt.sum(factor) for factor in self.logp(factors, sum=False)]
+        factor_logps_fn = self.compile_logp(factors, sum=False)
         return {
-            factor.name: np.round(np.asarray(factor_logp), round_vals)
+            factor.name: np.round(np.asarray(factor_logp.sum()), round_vals)
             for factor, factor_logp in zip(
                 factors,
-                self.compile_fn(factor_logps_fn, **kwargs)(point),
+                factor_logps_fn(point),
             )
         }
 
@@ -2167,6 +2224,10 @@ def compile_fn(
     -------
     Compiled PyTensor function
     """
+    warnings.warn(
+        "compile_fn is deprecated. Use `model.compile_fn` or `pytensorf.compile` instead.",
+        FutureWarning,
+    )
     model = modelcontext(model)
     return model.compile_fn(
         outs,
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 78eb3f7bbc..432561054e 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -36,6 +36,7 @@
     walk,
 )
 from pytensor.graph.fg import FunctionGraph, Output
+from pytensor.link.jax import JAXLinker
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
 from pytensor.tensor.basic import _as_tensor_variable
@@ -50,7 +51,7 @@
 from pytensor.tensor.variable import TensorVariable
 
 from pymc.exceptions import NotConstantValueError
-from pymc.util import makeiter
+from pymc.util import _get_seeds_per_chain, makeiter
 from pymc.vartypes import continuous_types, isgenerator, typefilter
 
 PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable
@@ -59,7 +60,6 @@
 __all__ = [
     "CallableTensor",
     "compile",
-    "compile_pymc",
     "cont_inputs",
     "convert_data",
     "convert_observed_data",
@@ -424,10 +424,11 @@ def make_shared_replacements(point, vars, model):
     -------
     Dict of variable -> new shared variable
     """
-    othervars = set(model.value_vars) - set(vars)
+    vars_set = set(vars)
     return {
         var: pytensor.shared(point[var.name], var.name + "_shared", shape=var.type.shape)
-        for var in othervars
+        for var in model.value_vars
+        if var not in vars_set
     }
 
 
@@ -686,6 +687,27 @@ def reseed_rngs(
         rng.set_value(np.random.Generator(bit_generator), borrow=True)
 
 
+def seed_compiled_function(function, seed: SeedSequenceSeed):
+    rng_variables = [
+        inp
+        for inp in function.maker.fgraph.inputs
+        if isinstance(inp, RandomGeneratorSharedVariable)
+    ]
+    if rng_variables:
+        if isinstance(function.maker.linker, JAXLinker):
+            import jax
+
+            (int_seed,) = _get_seeds_per_chain(seed, 1)
+            rng_values = jax.random.split(jax.random.key(int_seed), len(rng_variables))
+        else:
+            rng_values = [
+                np.random.Generator(np.random.PCG64(sub_seed))
+                for sub_seed in np.random.SeedSequence(seed).spawn(len(rng_variables))
+            ]
+        for rng_variable, rng_value in zip(rng_variables, rng_values):
+            rng_variable.set_value(rng_value, borrow=True)
+
+
 def collect_default_updates_inner_fgraph(node: Apply) -> dict[Variable, Variable]:
     """Collect default updates from node with inner fgraph."""
     op = node.op
@@ -877,7 +899,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
 def compile(
     inputs,
     outputs,
-    random_seed: SeedSequenceSeed = None,
+    random_seed: SeedSequenceSeed | bool = None,
     mode=None,
     **kwargs,
 ) -> Function:
@@ -926,8 +948,9 @@ def compile(
 
     # We always reseed random variables as this provides RNGs with no chances of collision
     if rng_updates:
-        rngs = cast(list[SharedVariable], list(rng_updates))
-        reseed_rngs(rngs, random_seed)
+        if random_seed is not False:
+            rngs = cast(list[SharedVariable], list(rng_updates))
+            reseed_rngs(rngs, random_seed)
 
     # If called inside a model context, see if check_bounds flag is set to False
     try:
@@ -954,14 +977,6 @@ def compile(
     return pytensor_function
 
 
-def compile_pymc(*args, **kwargs):
-    warnings.warn(
-        "compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
-        FutureWarning,
-    )
-    return compile(*args, **kwargs)
-
-
 def constant_fold(
     xs: Sequence[TensorVariable], raise_not_constant: bool = True
 ) -> tuple[np.ndarray | Variable, ...]:
diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py
index b1f9c39895..c82dfa94f6 100644
--- a/pymc/sampling/forward.py
+++ b/pymc/sampling/forward.py
@@ -30,9 +30,7 @@
 import xarray
 
 from arviz import InferenceData
-from pytensor import tensor as pt
 from pytensor.graph.basic import (
-    Apply,
     Constant,
     Variable,
     ancestors,
@@ -60,7 +58,6 @@
     _get_seeds_per_chain,
     default_progress_theme,
     get_default_varnames,
-    point_wrapper,
 )
 
 __all__ = (
@@ -113,9 +110,9 @@ def compile_forward_sampling_function(
     outputs: list[Variable],
     vars_in_trace: list[Variable],
     basic_rvs: list[Variable] | None = None,
-    givens_dict: dict[Variable, Any] | None = None,
     constant_data: dict[str, np.ndarray] | None = None,
     constant_coords: set[str] | None = None,
+    model=None,
     **kwargs,
 ) -> tuple[Callable[..., np.ndarray | list[np.ndarray]], set[Variable]]:
     """Compile a function to draw samples, conditioned on the values of some variables.
@@ -131,7 +128,6 @@ def compile_forward_sampling_function(
     - Variables in the outputs list
     - ``SharedVariable`` instances that are not ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time
     - Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
-    - Variables that are keys in the ``givens_dict``
     - Variables that have volatile inputs
 
     Concretely, this function can be used to compile a function to sample from the
@@ -142,15 +138,6 @@ def compile_forward_sampling_function(
     ignored and new values will be computed (in the case of deterministics and potentials) or
     sampled (in the case of random variables).
 
-    This function also enables a way to impute values for any variable in the computational
-    graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
-    to set the ``givens`` argument of the pytensor function compilation. This will essentially
-    replace a node in the computational graph with any other expression that has the same
-    type as the desired node. Passing variables in the givens_dict is considered an intervention
-    that might lead to different variable values from those that could have been seen during
-    inference, as such, **any variable that is passed in the ``givens_dict`` will be considered
-    volatile**.
-
     Parameters
     ----------
     outputs : List[pytensor.graph.basic.Variable]
@@ -163,10 +150,6 @@ def compile_forward_sampling_function(
         be considered as random variable instances. This includes variables that have
         a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or
         Censored distributions.
-    givens_dict : Optional[Dict[pytensor.graph.basic.Variable, Any]]
-        A dictionary that maps tensor variables to the values that should be used to replace them
-        in the compiled function. The types of the key and value should match or an error will be
-        raised during compilation.
     constant_data : Optional[Dict[str, numpy.ndarray]]
         A dictionary that maps the names of ``Data`` instances to their
         corresponding values at inference time. If a model was created with ``Data``, these
@@ -195,9 +178,6 @@ def compile_forward_sampling_function(
         Set of all basic_rvs that were considered volatile and will be resampled when
         the function is evaluated
     """
-    if givens_dict is None:
-        givens_dict = {}
-
     if basic_rvs is None:
         basic_rvs = []
 
@@ -226,7 +206,6 @@ def shared_value_matches(var):
     for node in nodes:
         if (
             node in fg.outputs
-            or node in givens_dict
             or (  # SharedVariables, except RandomState/Generators
                 isinstance(node, SharedVariable)
                 and not isinstance(node, RandomGeneratorSharedVariable)
@@ -263,20 +242,15 @@ def expand(node):
     # the entire graph
     list(walk(fg.outputs, expand))
 
-    # Populate the givens list
-    givens = [
-        (
-            node,
-            value
-            if isinstance(value, Variable | Apply)
-            else pt.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
-        )
-        for node, value in givens_dict.items()
-    ]
+    if model is None:
+        fn = compile(inputs, fg.outputs, on_unused_input="ignore", **kwargs)
+    else:
+        # Go through model to cache function
+        fn = model.compile_fn(fg.outputs, inputs=inputs, on_unused_input="ignore", **kwargs)
 
     return (
-        compile(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
-        set(basic_rvs) & (volatile_nodes - set(givens_dict)),  # Basic RVs that will be resampled
+        fn,
+        set(basic_rvs) & volatile_nodes,  # Basic RVs that will be resampled
     )
 
 
@@ -467,8 +441,6 @@ def sample_prior_predictive(
         vars_to_sample,
         vars_in_trace=[],
         basic_rvs=model.basic_RVs,
-        givens_dict=None,
-        random_seed=random_seed,
         **compile_kwargs,
     )
 
@@ -901,17 +873,16 @@ def sample_posterior_predictive(
     compile_kwargs.setdefault("allow_input_downcast", True)
     compile_kwargs.setdefault("accept_inplace", True)
 
-    _sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
+    sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
         outputs=vars_to_sample,
         vars_in_trace=vars_in_trace,
         basic_rvs=model.basic_RVs,
-        givens_dict=None,
-        random_seed=random_seed,
         constant_data=constant_data,
         constant_coords=constant_coords,
         **compile_kwargs,
+        random_seed=random_seed,
     )
-    sampler_fn = point_wrapper(_sampler_fn)
     # All model variables have a name, but mypy does not know this
     _log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}")  # type: ignore[arg-type, return-value]
     ppc_trace_t = _DefaultTrace(samples)
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index f2dfa6e9c2..6e1bf1fcc0 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -109,7 +109,6 @@ def instantiate_steppers(
     selected_steps: Mapping[type[BlockedStep], list[Any]],
     *,
     step_kwargs: dict[str, dict] | None = None,
-    initial_point: PointType | None = None,
     compile_kwargs: dict | None = None,
 ) -> Step | list[Step]:
     """Instantiate steppers assigned to the model variables.
@@ -141,9 +140,6 @@ def instantiate_steppers(
 
     used_keys = set()
     if selected_steps:
-        if initial_point is None:
-            initial_point = model.initial_point()
-
         for step_class, vars in selected_steps.items():
             if vars:
                 name = getattr(step_class, "name")
@@ -152,7 +148,6 @@ def instantiate_steppers(
                 step = step_class(
                     vars=vars,
                     model=model,
-                    initial_point=initial_point,
                     compile_kwargs=compile_kwargs,
                     **kwargs,
                 )
@@ -769,20 +764,8 @@ def joined_blas_limiter():
     rngs = get_random_generator(random_seed).spawn(chains)
     random_seed_list = [rng.integers(2**30) for rng in rngs]
 
-    if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace):
-        warnings.warn(
-            "Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
-            " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n"
-            "`pm.sample(..., return_inferencedata=True)`",
-            UserWarning,
-            stacklevel=2,
-        )
-
     # small trace warning
-    if draws == 0:
-        msg = "Tuning was enabled throughout the whole trace."
-        _log.warning(msg)
-    elif draws < 100:
+    if 0 < draws < 100:
         msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
         _log.warning(msg)
 
@@ -858,7 +841,6 @@ def joined_blas_limiter():
             steps=provided_steps,
             selected_steps=selected_steps,
             step_kwargs=kwargs,
-            initial_point=initial_points[0],
             compile_kwargs=compile_kwargs,
         )
         if isinstance(step, list):
@@ -1097,31 +1079,6 @@ def _sample_return(
     return mtrace
 
 
-def _check_start_shape(model, start: PointType):
-    """Check that the prior evaluations and initial points have identical shapes.
-
-    Parameters
-    ----------
-    model : pm.Model
-        The current model on context.
-    start : dict
-        The complete dictionary mapping (transformed) variable names to numeric initial values.
-    """
-    e = ""
-    try:
-        actual_shapes = model.eval_rv_shapes()
-    except NotImplementedError as ex:
-        warnings.warn(f"Unable to validate shapes: {ex.args[0]}", UserWarning)
-        return
-    for name, sval in start.items():
-        ashape = actual_shapes.get(name)
-        sshape = np.shape(sval)
-        if ashape != tuple(sshape):
-            e += f"\nExpected shape {ashape} for var '{name}', got: {sshape}"
-    if e != "":
-        raise ValueError(f"Bad shape in start point:{e}")
-
-
 def _sample_many(
     *,
     draws: int,
@@ -1595,12 +1552,13 @@ def init_nuts(
             pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
         ]
 
-    logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
-    logp_dlogp_func.trust_input = True
+    logp_dlogp_func = model.logp_dlogp_function(
+        ravel_inputs=True, trust_input=True, **compile_kwargs
+    )
 
     def model_logp_fn(ip: PointType) -> np.ndarray:
         q, _ = DictToArrayBijection.map(ip)
-        return logp_dlogp_func([q], extra_vars={})[0]
+        return logp_dlogp_func(q)[0]
 
     initial_points = _init_jitter(
         model,
@@ -1726,14 +1684,7 @@ def model_logp_fn(ip: PointType) -> np.ndarray:
     else:
         raise ValueError(f"Unknown initializer: {init}.")
 
-    step = pm.NUTS(
-        potential=potential,
-        model=model,
-        rng=random_seed_list[0],
-        initial_point=initial_points[0],
-        logp_dlogp_func=logp_dlogp_func,
-        **kwargs,
-    )
+    step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs)
 
     # Filter deterministics from initial_points
     value_var_names = [var.name for var in model.value_vars]
diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py
index 0c20e09a47..9367fa78de 100644
--- a/pymc/step_methods/arraystep.py
+++ b/pymc/step_methods/arraystep.py
@@ -179,26 +179,22 @@ def __init__(
         model=None,
         blocked: bool = True,
         dtype=None,
-        logp_dlogp_func=None,
         rng: RandomGenerator = None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         **pytensor_kwargs,
     ):
         model = modelcontext(model)
 
-        if logp_dlogp_func is None:
-            if compile_kwargs is None:
-                compile_kwargs = {}
-            logp_dlogp_func = model.logp_dlogp_function(
-                vars,
-                dtype=dtype,
-                ravel_inputs=True,
-                initial_point=initial_point,
-                **compile_kwargs,
-                **pytensor_kwargs,
-            )
-            logp_dlogp_func.trust_input = True
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        logp_dlogp_func = model.logp_dlogp_function(
+            vars,
+            dtype=dtype,
+            ravel_inputs=True,
+            trust_input=True,
+            **compile_kwargs,
+            **pytensor_kwargs,
+        )
 
         self._logp_dlogp_func = logp_dlogp_func
 
diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index e8c96e8c4b..9d0a171d4f 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -22,7 +22,7 @@
 
 import numpy as np
 
-from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
+from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
 from pymc.exceptions import SamplingError
 from pymc.model import Point, modelcontext
 from pymc.pytensorf import floatX
@@ -98,7 +98,6 @@ def __init__(
         adapt_step_size=True,
         step_rand=None,
         rng=None,
-        initial_point: PointType | None = None,
         **pytensor_kwargs,
     ):
         """Set up Hamiltonian samplers with common structures.
@@ -144,7 +143,6 @@ def __init__(
             model=self._model,
             dtype=dtype,
             rng=rng,
-            initial_point=initial_point,
             **pytensor_kwargs,
         )
 
@@ -152,9 +150,7 @@ def __init__(
         self.Emax = Emax
         self.iter_count = 0
 
-        if initial_point is None:
-            initial_point = self._model.initial_point()
-
+        initial_point = self._model.initial_point()
         nuts_vars = [initial_point[v.name] for v in vars]
         size = sum(v.size for v in nuts_vars)
 
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index 70c650653d..5d78a6708e 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -30,7 +30,6 @@
 import pymc as pm
 
 from pymc.blocking import DictToArrayBijection, RaveledVars
-from pymc.initial_point import PointType
 from pymc.pytensorf import (
     CallableTensor,
     compile,
@@ -163,7 +162,6 @@ def __init__(
         model=None,
         mode=None,
         rng=None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         blocked: bool = False,
     ):
@@ -194,8 +192,7 @@ def __init__(
             :py:func:`pymc.util.get_random_generator` for more information.
         """
         model = pm.modelcontext(model)
-        if initial_point is None:
-            initial_point = model.initial_point()
+        initial_point = model.initial_point()
 
         if vars is None:
             vars = model.value_vars
@@ -466,7 +463,6 @@ def __init__(
         tune_interval=100,
         model=None,
         rng=None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
@@ -591,7 +587,6 @@ def __init__(
         transit_p=0.8,
         model=None,
         rng=None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
@@ -605,8 +600,7 @@ def __init__(
 
         vars = get_value_vars_from_user_vars(vars, model)
 
-        if initial_point is None:
-            initial_point = model.initial_point()
+        initial_point = model.initial_point()
         self.dim = sum(initial_point[v.name].size for v in vars)
 
         if order == "random":
@@ -713,7 +707,6 @@ def __init__(
         order="random",
         model=None,
         rng: RandomGenerator = None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
@@ -721,8 +714,7 @@ def __init__(
 
         vars = get_value_vars_from_user_vars(vars, model)
 
-        if initial_point is None:
-            initial_point = model.initial_point()
+        initial_point = model.initial_point()
 
         dimcats: list[tuple[int, int]] = []
         # The above variable is a list of pairs (aggregate dimension, number
@@ -948,13 +940,11 @@ def __init__(
         model=None,
         mode=None,
         rng=None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
-        if initial_point is None:
-            initial_point = model.initial_point()
+        initial_point = model.initial_point()
         initial_values_size = sum(initial_point[n.name].size for n in model.value_vars)
 
         if vars is None:
@@ -1118,15 +1108,13 @@ def __init__(
         tune_interval=100,
         tune_drop_fraction: float = 0.9,
         model=None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         mode=None,
         rng=None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
-        if initial_point is None:
-            initial_point = model.initial_point()
+        initial_point = model.initial_point()
         initial_values_size = sum(initial_point[n.name].size for n in model.value_vars)
 
         if vars is None:
diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 9c10acfdf4..e27e1f87a2 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -21,7 +21,6 @@
 from rich.table import Column
 
 from pymc.blocking import RaveledVars, StatsType
-from pymc.initial_point import PointType
 from pymc.model import modelcontext
 from pymc.pytensorf import compile, join_nonshared_inputs, make_shared_replacements
 from pymc.step_methods.arraystep import ArrayStepShared
@@ -88,7 +87,6 @@ def __init__(
         model=None,
         iter_limit=np.inf,
         rng=None,
-        initial_point: PointType | None = None,
         compile_kwargs: dict | None = None,
         blocked: bool = False,  # Could be true since tuning is independent across dims?
     ):
@@ -103,8 +101,7 @@ def __init__(
         else:
             vars = get_value_vars_from_user_vars(vars, model)
 
-        if initial_point is None:
-            initial_point = model.initial_point()
+        initial_point = model.initial_point()
 
         shared = make_shared_replacements(initial_point, vars, model)
         [logp], raveled_inp = join_nonshared_inputs(
diff --git a/pymc/util.py b/pymc/util.py
index 979b3beebf..fe3791f3a8 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -399,7 +399,7 @@ def __setstate__(self, state):
         self.__dict__.update(state)
 
 
-def locally_cachedmethod(f):
+def memoize(f):
     from collections import defaultdict
 
     def self_cache_fn(f_name):
@@ -411,6 +411,16 @@ def cf(self):
     return cachedmethod(self_cache_fn(f.__name__), key=hash_key)(f)
 
 
+def invalidates_memoize(f):
+    @functools.wraps(f)
+    def wrapper_fn(self, *args, **kwargs):
+        if cache := getattr(self, "_cache", None):
+            cache.clear()
+        return f(self, *args, **kwargs)
+
+    return wrapper_fn
+
+
 def check_dist_not_registered(dist, model=None):
     """Check that a dist is not registered in the model already."""
     from pymc.model import modelcontext
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index deedfc8d9f..ce6446b110 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -82,8 +82,8 @@
     RandomState,
     WithMemoization,
     _get_seeds_per_chain,
-    locally_cachedmethod,
     makeiter,
+    memoize,
 )
 from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling
 from pymc.variational.updates import adagrad_window
@@ -150,12 +150,12 @@ def node_property(f):
         def wrapper(fn):
             ff = append_name(f)(fn)
             f_ = pytensor.config.change_flags(compute_test_value="off")(ff)
-            return property(locally_cachedmethod(f_))
+            return property(memoize(f_))
 
         return wrapper
     else:
         f_ = pytensor.config.change_flags(compute_test_value="off")(f)
-        return property(locally_cachedmethod(f_))
+        return property(memoize(f_))
 
 
 @pytensor.config.change_flags(compute_test_value="ignore")
diff --git a/pymc/variational/stein.py b/pymc/variational/stein.py
index 0534bb6fa4..f0ad019e18 100644
--- a/pymc/variational/stein.py
+++ b/pymc/variational/stein.py
@@ -17,7 +17,7 @@
 from pytensor.graph.replace import graph_replace
 
 from pymc.pytensorf import floatX
-from pymc.util import WithMemoization, locally_cachedmethod
+from pymc.util import WithMemoization, memoize
 from pymc.variational.opvi import node_property
 from pymc.variational.test_functions import rbf
 
@@ -93,6 +93,6 @@ def logp_norm(self):
             )
         return sized_symbolic_logp / self.approx.symbolic_normalizing_constant
 
-    @locally_cachedmethod
+    @memoize
     def _kernel(self):
         return self._kernel_f(self.input_joint_matrix)
diff --git a/tests/model/test_core.py b/tests/model/test_core.py
index b26a9d96b7..0c03f15d74 100644
--- a/tests/model/test_core.py
+++ b/tests/model/test_core.py
@@ -1855,3 +1855,21 @@ def test_guassian_process_copy_failure(self, copy_method) -> None:
             match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
         ):
             copy_method(gaussian_process_model)
+
+
+def test_memoization():
+    with pm.Model() as m:
+        x = pm.Normal("x")
+        y = pm.Normal("y")
+
+    res1 = m.logp()
+    res2 = m.logp()
+    res3 = m.logp(sum=False)
+
+    res4 = m.logp()
+    assert res1 is res2
+    assert res1 is not res3
+    assert res1 is res4
+
+    # add_coord is decorated with @invalidates_memoize, so mutating the model clears the cache
+    m.add_coord("obs", range(3))
+    assert m.logp() is not res1
diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 090b76130b..be69f21346 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -212,22 +212,6 @@ def test_reset_tuning(self):
             assert step.potential._n_samples == tune
             assert step.step_adapt._count == tune + 1
 
-    @pytest.mark.parametrize(
-        "start, error",
-        [
-            ({"x": 1}, ValueError),
-            ({"x": [1, 2, 3]}, ValueError),
-            ({"x": np.array([[1, 1], [1, 1]])}, ValueError),
-        ],
-    )
-    def test_sample_start_bad_shape(self, start, error):
-        with pytest.raises(error):
-            pm.sampling.mcmc._check_start_shape(self.model, start)
-
-    @pytest.mark.parametrize("start", [{"x": np.array([1, 1])}, {"x": [10, 10]}, {"x": [-10, -10]}])
-    def test_sample_start_good_shape(self, start):
-        pm.sampling.mcmc._check_start_shape(self.model, start)
-
     def test_sample_callback(self):
         callback = mock.Mock()
         test_cores = [1, 2]
diff --git a/tests/test_util.py b/tests/test_util.py
index 98cc168f0e..f405ea7cde 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -30,7 +30,7 @@
     get_value_vars_from_user_vars,
     hash_key,
     hashable,
-    locally_cachedmethod,
+    memoize,
 )
 
 
@@ -138,7 +138,7 @@ def some_func(x):
     assert some_func(b1) != some_func(b2)
 
     class TestClass:
-        @locally_cachedmethod
+        @memoize
         def some_method(self, x):
             return x