From 245afd7896007da982df46dbc98f1c7e4b007eae Mon Sep 17 00:00:00 2001
From: Ben Kimock <kimockb@gmail.com>
Date: Sun, 26 Nov 2023 17:06:13 -0500
Subject: [PATCH 1/2] Sometimes return the same AllocId for a ConstAllocation

---
 .../src/interpret/eval_context.rs             | 16 +++---
 .../rustc_const_eval/src/interpret/machine.rs | 20 +++++++
 src/tools/miri/src/machine.rs                 | 52 ++++++++++++++++++-
 src/tools/miri/tests/pass/const-addrs.rs      | 38 ++++++++++++++
 4 files changed, 117 insertions(+), 9 deletions(-)
 create mode 100644 src/tools/miri/tests/pass/const-addrs.rs

diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs
index bbebf329d2659..ecab6fa387ded 100644
--- a/compiler/rustc_const_eval/src/interpret/eval_context.rs
+++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs
@@ -1138,13 +1138,15 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
         span: Option<Span>,
         layout: Option<TyAndLayout<'tcx>>,
     ) -> InterpResult<'tcx, OpTy<'tcx, M::Provenance>> {
-        let const_val = val.eval(*self.tcx, self.param_env, span).map_err(|err| {
-            // FIXME: somehow this is reachable even when POST_MONO_CHECKS is on.
-            // Are we not always populating `required_consts`?
-            err.emit_note(*self.tcx);
-            err
-        })?;
-        self.const_val_to_op(const_val, val.ty(), layout)
+        M::eval_mir_constant(self, *val, span, layout, |ecx, val, span, layout| {
+            let const_val = val.eval(*ecx.tcx, ecx.param_env, span).map_err(|err| {
+                // FIXME: somehow this is reachable even when POST_MONO_CHECKS is on.
+                // Are we not always populating `required_consts`?
+                err.emit_note(*ecx.tcx);
+                err
+            })?;
+            ecx.const_val_to_op(const_val, val.ty(), layout)
+        })
     }
 
     #[must_use]
diff --git a/compiler/rustc_const_eval/src/interpret/machine.rs b/compiler/rustc_const_eval/src/interpret/machine.rs
index 5e69965512b94..d32a29ede4982 100644
--- a/compiler/rustc_const_eval/src/interpret/machine.rs
+++ b/compiler/rustc_const_eval/src/interpret/machine.rs
@@ -13,6 +13,7 @@ use rustc_middle::query::TyCtxtAt;
 use rustc_middle::ty;
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_span::def_id::DefId;
+use rustc_span::Span;
 use rustc_target::abi::{Align, Size};
 use rustc_target::spec::abi::Abi as CallAbi;
 
@@ -510,6 +511,25 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
     ) -> InterpResult<'tcx> {
         Ok(())
     }
+
+    #[inline(always)]
+    fn eval_mir_constant<F>(
+        ecx: &InterpCx<'mir, 'tcx, Self>,
+        val: mir::Const<'tcx>,
+        span: Option<Span>,
+        layout: Option<TyAndLayout<'tcx>>,
+        eval: F,
+    ) -> InterpResult<'tcx, OpTy<'tcx, Self::Provenance>>
+    where
+        F: Fn(
+            &InterpCx<'mir, 'tcx, Self>,
+            mir::Const<'tcx>,
+            Option<Span>,
+            Option<TyAndLayout<'tcx>>,
+        ) -> InterpResult<'tcx, OpTy<'tcx, Self::Provenance>>,
+    {
+        eval(ecx, val, span, layout)
+    }
 }
 
 /// A lot of the flexibility above is just needed for `Miri`, but all "compile-time" machines
diff --git a/src/tools/miri/src/machine.rs b/src/tools/miri/src/machine.rs
index 4a878f4a36e6c..4cd48b2e5a3a4 100644
--- a/src/tools/miri/src/machine.rs
+++ b/src/tools/miri/src/machine.rs
@@ -3,6 +3,7 @@
 
 use std::borrow::Cow;
 use std::cell::{Cell, RefCell};
+use std::collections::hash_map::Entry;
 use std::fmt;
 use std::path::Path;
 use std::process;
@@ -10,6 +11,7 @@ use std::process;
 use either::Either;
 use rand::rngs::StdRng;
 use rand::SeedableRng;
+use rand::Rng;
 
 use rustc_ast::ast::Mutability;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
@@ -45,6 +47,11 @@ pub const SIGRTMIN: i32 = 34;
 /// `SIGRTMAX` - `SIGRTMIN` >= 8 (which is the value of `_POSIX_RTSIG_MAX`)
 pub const SIGRTMAX: i32 = 42;
 
+/// Each const has multiple addresses, but only this many. Since const allocations are never
+/// deallocated, choosing a new [`AllocId`] and thus base address for each evaluation would
+/// produce unbounded memory usage.
+const ADDRS_PER_CONST: usize = 16;
+
 /// Extra data stored with each stack frame
 pub struct FrameExtra<'tcx> {
     /// Extra data for the Borrow Tracker.
@@ -65,12 +72,18 @@ pub struct FrameExtra<'tcx> {
     /// optimization.
     /// This is used by `MiriMachine::current_span` and `MiriMachine::caller_span`
     pub is_user_relevant: bool,
+
+    /// We have a cache for the mapping from [`mir::Const`] to resulting [`AllocId`].
+    /// However, we don't want all frames to always get the same result, so we insert
+    /// an additional bit of "salt" into the cache key. This salt is fixed per-frame
+    /// so that within a call, a const will have a stable address.
+    salt: usize,
 }
 
 impl<'tcx> std::fmt::Debug for FrameExtra<'tcx> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         // Omitting `timing`, it does not support `Debug`.
-        let FrameExtra { borrow_tracker, catch_unwind, timing: _, is_user_relevant: _ } = self;
+        let FrameExtra { borrow_tracker, catch_unwind, timing: _, is_user_relevant: _, salt: _ } = self;
         f.debug_struct("FrameData")
             .field("borrow_tracker", borrow_tracker)
             .field("catch_unwind", catch_unwind)
@@ -80,7 +93,7 @@ impl<'tcx> std::fmt::Debug for FrameExtra<'tcx> {
 
 impl VisitProvenance for FrameExtra<'_> {
     fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
-        let FrameExtra { catch_unwind, borrow_tracker, timing: _, is_user_relevant: _ } = self;
+        let FrameExtra { catch_unwind, borrow_tracker, timing: _, is_user_relevant: _, salt: _ } = self;
 
         catch_unwind.visit_provenance(visit);
         borrow_tracker.visit_provenance(visit);
@@ -552,6 +565,11 @@ pub struct MiriMachine<'mir, 'tcx> {
     /// The spans we will use to report where an allocation was created and deallocated in
     /// diagnostics.
     pub(crate) allocation_spans: RefCell<FxHashMap<AllocId, (Span, Option<Span>)>>,
+
+    /// Maps MIR consts to their evaluated result. We combine the const with a "salt" (`usize`)
+    /// that is fixed per stack frame; this lets us have sometimes different results for the
+    /// same const while ensuring consistent results within a single call.
+    const_cache: RefCell<FxHashMap<(mir::Const<'tcx>, usize), OpTy<'tcx, Provenance>>>,
 }
 
 impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
@@ -677,6 +695,7 @@ impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
             stack_size,
             collect_leak_backtraces: config.collect_leak_backtraces,
             allocation_spans: RefCell::new(FxHashMap::default()),
+            const_cache: RefCell::new(FxHashMap::default()),
         }
     }
 
@@ -857,6 +876,7 @@ impl VisitProvenance for MiriMachine<'_, '_> {
             stack_size: _,
             collect_leak_backtraces: _,
             allocation_spans: _,
+            const_cache: _,
         } = self;
 
         threads.visit_provenance(visit);
@@ -1400,6 +1420,7 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> {
             catch_unwind: None,
             timing,
             is_user_relevant: ecx.machine.is_user_relevant(&frame),
+            salt: ecx.machine.rng.borrow_mut().gen::<usize>() % ADDRS_PER_CONST,
         };
 
         Ok(frame.with_extra(extra))
@@ -1505,4 +1526,31 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> {
         ecx.machine.allocation_spans.borrow_mut().insert(alloc_id, (span, None));
         Ok(())
     }
+
+    fn eval_mir_constant<F>(
+        ecx: &InterpCx<'mir, 'tcx, Self>,
+        val: mir::Const<'tcx>,
+        span: Option<Span>,
+        layout: Option<TyAndLayout<'tcx>>,
+        eval: F,
+    ) -> InterpResult<'tcx, OpTy<'tcx, Self::Provenance>>
+    where
+        F: Fn(
+            &InterpCx<'mir, 'tcx, Self>,
+            mir::Const<'tcx>,
+            Option<Span>,
+            Option<TyAndLayout<'tcx>>,
+        ) -> InterpResult<'tcx, OpTy<'tcx, Self::Provenance>>,
+    {
+        let frame = ecx.active_thread_stack().last().unwrap();
+        let mut cache = ecx.machine.const_cache.borrow_mut();
+        match cache.entry((val, frame.extra.salt)) {
+            Entry::Vacant(ve) => {
+                let op = eval(ecx, val, span, layout)?;
+                ve.insert(op.clone());
+                Ok(op)
+            }
+            Entry::Occupied(oe) => Ok(oe.get().clone()),
+        }
+    }
 }
diff --git a/src/tools/miri/tests/pass/const-addrs.rs b/src/tools/miri/tests/pass/const-addrs.rs
new file mode 100644
index 0000000000000..6c14f0b679ce8
--- /dev/null
+++ b/src/tools/miri/tests/pass/const-addrs.rs
@@ -0,0 +1,38 @@
+// The const fn interpreter creates a new AllocId every time it evaluates any const.
+// If we do that in Miri, repeatedly evaluating a const causes unbounded memory use
+// we need to keep track of the base address for that AllocId, and the allocation is never
+// deallocated.
+// In Miri we explicitly store previously-assigned AllocIds for each const and ensure
+// that we only hand out a finite number of AllocIds per const.
+// MIR inlining will put every evaluation of the const we're repeatedly evaluting into the same
+// stack frame, breaking this test.
+//@compile-flags: -Zinline-mir=no
+#![feature(strict_provenance)]
+
+const EVALS: usize = 256;
+
+use std::collections::HashSet;
+fn main() {
+    let mut addrs = HashSet::new();
+    for _ in 0..EVALS {
+        addrs.insert(const_addr());
+    }
+    // Check that the const allocation has multiple base addresses
+    assert!(addrs.len() > 1);
+    // But also that we get a limited number of unique base addresses
+    assert!(addrs.len() < EVALS);
+
+    // Check that within a call we always produce the same address
+    let mut prev = 0;
+    for iter in 0..EVALS {
+        let addr = "test".as_bytes().as_ptr().addr();
+        if iter > 0 {
+            assert_eq!(prev, addr);
+        }
+        prev = addr;
+    }
+}
+
+fn const_addr() -> usize {
+    "test".as_bytes().as_ptr().addr()
+}

From c8a675d752364a527967a6ce2be836e385df1bb2 Mon Sep 17 00:00:00 2001
From: Ben Kimock <kimockb@gmail.com>
Date: Tue, 23 Jan 2024 10:17:50 -0500
Subject: [PATCH 2/2] Add a doc comment for eval_mir_constant

Co-authored-by: Ralf Jung <post@ralfj.de>
---
 compiler/rustc_const_eval/src/interpret/machine.rs | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/compiler/rustc_const_eval/src/interpret/machine.rs b/compiler/rustc_const_eval/src/interpret/machine.rs
index d32a29ede4982..b981a1ee2ca16 100644
--- a/compiler/rustc_const_eval/src/interpret/machine.rs
+++ b/compiler/rustc_const_eval/src/interpret/machine.rs
@@ -512,6 +512,8 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
         Ok(())
     }
 
+    /// Evaluate the given constant. The `eval` function will do all the required evaluation,
+    /// but this hook has the chance to do some pre/postprocessing.
     #[inline(always)]
     fn eval_mir_constant<F>(
         ecx: &InterpCx<'mir, 'tcx, Self>,