Skip to content

Commit 04a2eec

Browse files
Cache instantiation of canonical binder
1 parent d1d8e38 commit 04a2eec

File tree

2 files changed

+177
-20
lines changed

2 files changed

+177
-20
lines changed

compiler/rustc_infer/src/infer/canonical/instantiate.rs

Lines changed: 169 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
//! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html
88
99
use rustc_macros::extension;
10-
use rustc_middle::bug;
11-
use rustc_middle::ty::{self, FnMutDelegate, GenericArgKind, TyCtxt, TypeFoldable};
10+
use rustc_middle::ty::{
11+
self, DelayedMap, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeSuperVisitable,
12+
TypeVisitableExt, TypeVisitor,
13+
};
14+
use rustc_type_ir::TypeVisitable;
1215

1316
use crate::infer::canonical::{Canonical, CanonicalVarValues};
1417

@@ -58,23 +61,169 @@ where
5861
T: TypeFoldable<TyCtxt<'tcx>>,
5962
{
6063
if var_values.var_values.is_empty() {
61-
value
62-
} else {
63-
let delegate = FnMutDelegate {
64-
regions: &mut |br: ty::BoundRegion| match var_values[br.var].kind() {
65-
GenericArgKind::Lifetime(l) => l,
66-
r => bug!("{:?} is a region but value is {:?}", br, r),
67-
},
68-
types: &mut |bound_ty: ty::BoundTy| match var_values[bound_ty.var].kind() {
69-
GenericArgKind::Type(ty) => ty,
70-
r => bug!("{:?} is a type but value is {:?}", bound_ty, r),
71-
},
72-
consts: &mut |bound_ct: ty::BoundVar| match var_values[bound_ct].kind() {
73-
GenericArgKind::Const(ct) => ct,
74-
c => bug!("{:?} is a const but value is {:?}", bound_ct, c),
75-
},
76-
};
77-
78-
tcx.replace_escaping_bound_vars_uncached(value, delegate)
64+
return value;
7965
}
66+
67+
value.fold_with(&mut CanonicalInstantiator {
68+
tcx,
69+
current_index: ty::INNERMOST,
70+
var_values: var_values.var_values,
71+
cache: Default::default(),
72+
})
73+
}
74+
75+
/// Replaces the bound vars in a canonical binder with var values.
76+
struct CanonicalInstantiator<'tcx> {
77+
tcx: TyCtxt<'tcx>,
78+
79+
// The values that the bound vars are are being instantiated with.
80+
var_values: ty::GenericArgsRef<'tcx>,
81+
82+
/// As with `BoundVarReplacer`, represents the index of a binder *just outside*
83+
/// the ones we have visited.
84+
current_index: ty::DebruijnIndex,
85+
86+
// Instantiation is a pure function of `DebruijnIndex` and `Ty`.
87+
cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
88+
}
89+
90+
impl<'tcx> TypeFolder<TyCtxt<'tcx>> for CanonicalInstantiator<'tcx> {
91+
fn cx(&self) -> TyCtxt<'tcx> {
92+
self.tcx
93+
}
94+
95+
fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
96+
&mut self,
97+
t: ty::Binder<'tcx, T>,
98+
) -> ty::Binder<'tcx, T> {
99+
self.current_index.shift_in(1);
100+
let t = t.super_fold_with(self);
101+
self.current_index.shift_out(1);
102+
t
103+
}
104+
105+
fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
106+
match *t.kind() {
107+
ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
108+
self.var_values[bound_ty.var.as_usize()].expect_ty()
109+
}
110+
_ => {
111+
if !t.has_vars_bound_at_or_above(self.current_index) {
112+
t
113+
} else if let Some(&t) = self.cache.get(&(self.current_index, t)) {
114+
t
115+
} else {
116+
let res = t.super_fold_with(self);
117+
assert!(self.cache.insert((self.current_index, t), res));
118+
res
119+
}
120+
}
121+
}
122+
}
123+
124+
fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
125+
match r.kind() {
126+
ty::ReBound(debruijn, br) if debruijn == self.current_index => {
127+
self.var_values[br.var.as_usize()].expect_region()
128+
}
129+
_ => r,
130+
}
131+
}
132+
133+
fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
134+
match ct.kind() {
135+
ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
136+
self.var_values[bound_const.as_usize()].expect_const()
137+
}
138+
_ => ct.super_fold_with(self),
139+
}
140+
}
141+
142+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
143+
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
144+
}
145+
146+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
147+
if !c.has_vars_bound_at_or_above(self.current_index) {
148+
return c;
149+
}
150+
151+
// Since instantiation is a function of `DebruijnIndex`, we don't want
152+
// to have to cache more copies of clauses when we're inside of binders.
153+
// Since we currently expect to only have clauses in the outermost
154+
// debruijn index, we just fold if we're inside of a binder.
155+
if self.current_index > ty::INNERMOST {
156+
return c.super_fold_with(self);
157+
}
158+
159+
// Our cache key is `(clauses, var_values)`, but we also don't care about
160+
// var values that aren't named in the clauses, since they can change without
161+
// affecting the output. Since `ParamEnv`s are cached first, we compute the
162+
// last var value that is mentioned in the clauses, and cut off the list so
163+
// that we have more hits in the cache.
164+
165+
// We also cache the computation of "highest var named by clauses" since that
166+
// is both expensive (depending on the size of the clauses) and a pure function.
167+
let index = *self
168+
.tcx
169+
.highest_var_in_clauses_cache
170+
.lock()
171+
.entry(c)
172+
.or_insert_with(|| highest_var_in_clauses(c));
173+
let c_args = &self.var_values[..=index];
174+
175+
if let Some(c) = self.tcx.clauses_cache.lock().get(&(c, c_args)) {
176+
c
177+
} else {
178+
let folded = c.super_fold_with(self);
179+
self.tcx.clauses_cache.lock().insert((c, c_args), folded);
180+
folded
181+
}
182+
}
183+
}
184+
185+
fn highest_var_in_clauses<'tcx>(c: ty::Clauses<'tcx>) -> usize {
186+
struct HighestVarInClauses {
187+
max_var: usize,
188+
current_index: ty::DebruijnIndex,
189+
}
190+
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HighestVarInClauses {
191+
fn visit_binder<T: TypeVisitable<TyCtxt<'tcx>>>(
192+
&mut self,
193+
t: &ty::Binder<'tcx, T>,
194+
) -> Self::Result {
195+
self.current_index.shift_in(1);
196+
let t = t.super_visit_with(self);
197+
self.current_index.shift_out(1);
198+
t
199+
}
200+
fn visit_ty(&mut self, t: Ty<'tcx>) {
201+
if let ty::Bound(debruijn, bound_ty) = *t.kind()
202+
&& debruijn == self.current_index
203+
{
204+
self.max_var = self.max_var.max(bound_ty.var.as_usize());
205+
} else if t.has_vars_bound_at_or_above(self.current_index) {
206+
t.super_visit_with(self);
207+
}
208+
}
209+
fn visit_region(&mut self, r: ty::Region<'tcx>) {
210+
if let ty::ReBound(debruijn, bound_region) = r.kind()
211+
&& debruijn == self.current_index
212+
{
213+
self.max_var = self.max_var.max(bound_region.var.as_usize());
214+
}
215+
}
216+
fn visit_const(&mut self, ct: ty::Const<'tcx>) {
217+
if let ty::ConstKind::Bound(debruijn, bound_const) = ct.kind()
218+
&& debruijn == self.current_index
219+
{
220+
self.max_var = self.max_var.max(bound_const.as_usize());
221+
} else if ct.has_vars_bound_at_or_above(self.current_index) {
222+
ct.super_visit_with(self);
223+
}
224+
}
225+
}
226+
let mut visitor = HighestVarInClauses { max_var: 0, current_index: ty::INNERMOST };
227+
c.visit_with(&mut visitor);
228+
visitor.max_var
80229
}

compiler/rustc_middle/src/ty/context.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,12 @@ pub struct GlobalCtxt<'tcx> {
14791479

14801480
pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,
14811481

1482+
/// Caches the index of the highest bound var in clauses in a canonical binder.
1483+
pub highest_var_in_clauses_cache: Lock<FxHashMap<ty::Clauses<'tcx>, usize>>,
1484+
/// Caches the instantiation of a canonical binder given a set of args.
1485+
pub clauses_cache:
1486+
Lock<FxHashMap<(ty::Clauses<'tcx>, &'tcx [ty::GenericArg<'tcx>]), ty::Clauses<'tcx>>>,
1487+
14821488
/// Data layout specification for the current target.
14831489
pub data_layout: TargetDataLayout,
14841490

@@ -1727,6 +1733,8 @@ impl<'tcx> TyCtxt<'tcx> {
17271733
new_solver_evaluation_cache: Default::default(),
17281734
new_solver_canonical_param_env_cache: Default::default(),
17291735
canonical_param_env_cache: Default::default(),
1736+
highest_var_in_clauses_cache: Default::default(),
1737+
clauses_cache: Default::default(),
17301738
data_layout,
17311739
alloc_map: interpret::AllocMap::new(),
17321740
current_gcx,

0 commit comments

Comments
 (0)