Skip to content

Commit b606b49

Browse files
committed
Guard HIR lowered contracts with contract_checks
Refactor contract HIR lowering to ensure no contract code is executed when contract-checks are disabled. The call to contract_checks is moved to inside the lowered fn body, and contract closures are built conditionally, ensuring no side-effects present in contracts occur when those are disabled.
1 parent b56aaec commit b606b49

File tree

18 files changed

+459
-120
lines changed

18 files changed

+459
-120
lines changed
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
use crate::LoweringContext;
2+
3+
impl<'a, 'hir> LoweringContext<'a, 'hir> {
4+
pub(super) fn lower_contract(
5+
&mut self,
6+
body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
7+
contract: &rustc_ast::FnContract,
8+
) -> rustc_hir::Expr<'hir> {
9+
match (&contract.requires, &contract.ensures) {
10+
(Some(req), Some(ens)) => {
11+
// Lower the fn contract, which turns:
12+
//
13+
// { body }
14+
//
15+
// into:
16+
//
17+
// {
18+
// let __postcond = if contracts_checks() {
19+
// contract_check_requires(PRECOND);
20+
// Some(|ret_val| POSTCOND)
21+
// } else {
22+
// None
23+
// };
24+
// contract_check_ensures(__postcond, { body })
25+
// }
26+
27+
let precond = self.lower_precond(req);
28+
let postcond_checker = self.lower_postcond_checker(ens);
29+
30+
let contract_check =
31+
self.lower_contract_check_with_postcond(Some(precond), postcond_checker);
32+
33+
let wrapped_body =
34+
self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
35+
self.expr_block(wrapped_body)
36+
}
37+
(None, Some(ens)) => {
38+
// Lower the fn contract, which turns:
39+
//
40+
// { body }
41+
//
42+
// into:
43+
//
44+
// {
45+
// let __postcond = if contracts_check() {
46+
// Some(|ret_val| POSTCOND)
47+
// } else {
48+
// None
49+
// };
50+
// __postcond({ body })
51+
// }
52+
53+
let postcond_checker = self.lower_postcond_checker(ens);
54+
let contract_check =
55+
self.lower_contract_check_with_postcond(None, postcond_checker);
56+
57+
let wrapped_body =
58+
self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
59+
self.expr_block(wrapped_body)
60+
}
61+
(Some(req), None) => {
62+
// Lower the fn contract, which turns:
63+
//
64+
// { body }
65+
//
66+
// into:
67+
//
68+
// {
69+
// if contracts_check() {
70+
// contract_requires(PRECOND);
71+
// }
72+
// body
73+
// }
74+
let precond = self.lower_precond(req);
75+
let precond_check = self.lower_contract_check_just_precond(precond);
76+
77+
let body = self.arena.alloc(body(self));
78+
79+
// Flatten the body into precond check, then body.
80+
let wrapped_body = self.block_all(
81+
body.span,
82+
self.arena.alloc_from_iter([precond_check].into_iter()),
83+
Some(body),
84+
);
85+
self.expr_block(wrapped_body)
86+
}
87+
(None, None) => body(self),
88+
}
89+
}
90+
91+
/// Lower the precondition check intrinsic.
92+
fn lower_precond(&mut self, req: &Box<rustc_ast::Expr>) -> rustc_hir::Stmt<'hir> {
93+
let lowered_req = self.lower_expr_mut(&req);
94+
let req_span = self.mark_span_with_reason(
95+
rustc_span::DesugaringKind::Contract,
96+
lowered_req.span,
97+
None,
98+
);
99+
let precond = self.expr_call_lang_item_fn_mut(
100+
req_span,
101+
rustc_hir::LangItem::ContractCheckRequires,
102+
&*arena_vec![self; lowered_req],
103+
);
104+
self.stmt_expr(req.span, precond)
105+
}
106+
107+
fn lower_postcond_checker(
108+
&mut self,
109+
ens: &Box<rustc_ast::Expr>,
110+
) -> &'hir rustc_hir::Expr<'hir> {
111+
let ens_span = self.lower_span(ens.span);
112+
let ens_span =
113+
self.mark_span_with_reason(rustc_span::DesugaringKind::Contract, ens_span, None);
114+
let lowered_ens = self.lower_expr_mut(&ens);
115+
self.expr_call_lang_item_fn(
116+
ens_span,
117+
rustc_hir::LangItem::ContractBuildCheckEnsures,
118+
&*arena_vec![self; lowered_ens],
119+
)
120+
}
121+
122+
fn lower_contract_check_just_precond(
123+
&mut self,
124+
precond: rustc_hir::Stmt<'hir>,
125+
) -> rustc_hir::Stmt<'hir> {
126+
let stmts = self.arena.alloc_from_iter([precond].into_iter());
127+
128+
let then_block_stmts = self.block_all(precond.span, stmts, None);
129+
let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
130+
131+
let precond_check = rustc_hir::ExprKind::If(
132+
self.expr_call_lang_item_fn(
133+
precond.span,
134+
rustc_hir::LangItem::ContractChecks,
135+
Default::default(),
136+
),
137+
then_block,
138+
None,
139+
);
140+
141+
let precond_check = self.expr(precond.span, precond_check);
142+
self.stmt_expr(precond.span, precond_check)
143+
}
144+
145+
fn lower_contract_check_with_postcond(
146+
&mut self,
147+
precond: Option<rustc_hir::Stmt<'hir>>,
148+
postcond_checker: &'hir rustc_hir::Expr<'hir>,
149+
) -> &'hir rustc_hir::Expr<'hir> {
150+
let stmts = self.arena.alloc_from_iter(precond.into_iter());
151+
let span = match precond {
152+
Some(precond) => precond.span,
153+
None => postcond_checker.span,
154+
};
155+
156+
let postcond_checker = self.arena.alloc(self.expr_enum_variant_lang_item(
157+
postcond_checker.span,
158+
rustc_hir::lang_items::LangItem::OptionSome,
159+
&*arena_vec![self; *postcond_checker],
160+
));
161+
let then_block_stmts = self.block_all(span, stmts, Some(postcond_checker));
162+
let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
163+
164+
let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
165+
postcond_checker.span,
166+
rustc_hir::lang_items::LangItem::OptionNone,
167+
Default::default(),
168+
));
169+
let else_block = self.block_expr(none_expr);
170+
let else_block = self.arena.alloc(self.expr_block(else_block));
171+
172+
let contract_check = rustc_hir::ExprKind::If(
173+
self.expr_call_lang_item_fn(
174+
span,
175+
rustc_hir::LangItem::ContractChecks,
176+
Default::default(),
177+
),
178+
then_block,
179+
Some(else_block),
180+
);
181+
self.arena.alloc(self.expr(span, contract_check))
182+
}
183+
184+
fn wrap_body_with_contract_check(
185+
&mut self,
186+
body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
187+
contract_check: &'hir rustc_hir::Expr<'hir>,
188+
postcond_span: rustc_span::Span,
189+
) -> &'hir rustc_hir::Block<'hir> {
190+
let check_ident: rustc_span::Ident =
191+
rustc_span::Ident::from_str_and_span("__ensures_checker", postcond_span);
192+
let (check_hir_id, postcond_decl) = {
193+
// Set up the postcondition `let` statement.
194+
let (checker_pat, check_hir_id) = self.pat_ident_binding_mode_mut(
195+
postcond_span,
196+
check_ident,
197+
rustc_hir::BindingMode::NONE,
198+
);
199+
(
200+
check_hir_id,
201+
self.stmt_let_pat(
202+
None,
203+
postcond_span,
204+
Some(contract_check),
205+
self.arena.alloc(checker_pat),
206+
rustc_hir::LocalSource::Contract,
207+
),
208+
)
209+
};
210+
211+
// Install contract_ensures so we will intercept `return` statements,
212+
// then lower the body.
213+
self.contract_ensures = Some((postcond_span, check_ident, check_hir_id));
214+
let body = self.arena.alloc(body(self));
215+
216+
// Finally, inject an ensures check on the implicit return of the body.
217+
let body = self.inject_ensures_check(body, postcond_span, check_ident, check_hir_id);
218+
219+
// Flatten the body into precond, then postcond, then wrapped body.
220+
let wrapped_body = self.block_all(
221+
body.span,
222+
self.arena.alloc_from_iter([postcond_decl].into_iter()),
223+
Some(body),
224+
);
225+
wrapped_body
226+
}
227+
228+
/// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
229+
/// a contract ensures clause, if it exists.
230+
pub(super) fn checked_return(
231+
&mut self,
232+
opt_expr: Option<&'hir rustc_hir::Expr<'hir>>,
233+
) -> rustc_hir::ExprKind<'hir> {
234+
let checked_ret =
235+
if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
236+
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
237+
Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
238+
} else {
239+
opt_expr
240+
};
241+
rustc_hir::ExprKind::Ret(checked_ret)
242+
}
243+
244+
/// Wraps an expression with a call to the ensures check before it gets returned.
245+
pub(super) fn inject_ensures_check(
246+
&mut self,
247+
expr: &'hir rustc_hir::Expr<'hir>,
248+
span: rustc_span::Span,
249+
cond_ident: rustc_span::Ident,
250+
cond_hir_id: rustc_hir::HirId,
251+
) -> &'hir rustc_hir::Expr<'hir> {
252+
let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
253+
let call_expr = self.expr_call_lang_item_fn_mut(
254+
span,
255+
rustc_hir::LangItem::ContractCheckEnsures,
256+
arena_vec![self; *cond_fn, *expr],
257+
);
258+
self.arena.alloc(call_expr)
259+
}
260+
}

compiler/rustc_ast_lowering/src/expr.rs

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -380,36 +380,6 @@ impl<'hir> LoweringContext<'_, 'hir> {
380380
})
381381
}
382382

383-
/// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
384-
/// a contract ensures clause, if it exists.
385-
fn checked_return(&mut self, opt_expr: Option<&'hir hir::Expr<'hir>>) -> hir::ExprKind<'hir> {
386-
let checked_ret =
387-
if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
388-
let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
389-
Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
390-
} else {
391-
opt_expr
392-
};
393-
hir::ExprKind::Ret(checked_ret)
394-
}
395-
396-
/// Wraps an expression with a call to the ensures check before it gets returned.
397-
pub(crate) fn inject_ensures_check(
398-
&mut self,
399-
expr: &'hir hir::Expr<'hir>,
400-
span: Span,
401-
cond_ident: Ident,
402-
cond_hir_id: HirId,
403-
) -> &'hir hir::Expr<'hir> {
404-
let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
405-
let call_expr = self.expr_call_lang_item_fn_mut(
406-
span,
407-
hir::LangItem::ContractCheckEnsures,
408-
arena_vec![self; *cond_fn, *expr],
409-
);
410-
self.arena.alloc(call_expr)
411-
}
412-
413383
pub(crate) fn lower_const_block(&mut self, c: &AnonConst) -> hir::ConstBlock {
414384
self.with_new_scopes(c.value.span, |this| {
415385
let def_id = this.local_def_id(c.id);
@@ -2095,7 +2065,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
20952065
self.expr(span, hir::ExprKind::AddrOf(hir::BorrowKind::Ref, hir::Mutability::Mut, e))
20962066
}
20972067

2098-
fn expr_unit(&mut self, sp: Span) -> &'hir hir::Expr<'hir> {
2068+
pub(super) fn expr_unit(&mut self, sp: Span) -> &'hir hir::Expr<'hir> {
20992069
self.arena.alloc(self.expr(sp, hir::ExprKind::Tup(&[])))
21002070
}
21012071

@@ -2133,6 +2103,43 @@ impl<'hir> LoweringContext<'_, 'hir> {
21332103
self.expr(span, hir::ExprKind::Call(e, args))
21342104
}
21352105

2106+
pub(super) fn expr_struct(
2107+
&mut self,
2108+
span: Span,
2109+
path: &'hir hir::QPath<'hir>,
2110+
fields: &'hir [hir::ExprField<'hir>],
2111+
) -> hir::Expr<'hir> {
2112+
self.expr(span, hir::ExprKind::Struct(path, fields, rustc_hir::StructTailExpr::None))
2113+
}
2114+
2115+
pub(super) fn expr_enum_variant(
2116+
&mut self,
2117+
span: Span,
2118+
path: &'hir hir::QPath<'hir>,
2119+
fields: &'hir [hir::Expr<'hir>],
2120+
) -> hir::Expr<'hir> {
2121+
let fields = self.arena.alloc_from_iter(fields.into_iter().enumerate().map(|(i, f)| {
2122+
hir::ExprField {
2123+
hir_id: self.next_id(),
2124+
ident: Ident::from_str(&i.to_string()),
2125+
expr: f,
2126+
span: f.span,
2127+
is_shorthand: false,
2128+
}
2129+
}));
2130+
self.expr_struct(span, path, fields)
2131+
}
2132+
2133+
pub(super) fn expr_enum_variant_lang_item(
2134+
&mut self,
2135+
span: Span,
2136+
lang_item: hir::LangItem,
2137+
fields: &'hir [hir::Expr<'hir>],
2138+
) -> hir::Expr<'hir> {
2139+
let path = self.arena.alloc(self.lang_item_path(span, lang_item));
2140+
self.expr_enum_variant(span, path, fields)
2141+
}
2142+
21362143
pub(super) fn expr_call(
21372144
&mut self,
21382145
span: Span,
@@ -2161,8 +2168,21 @@ impl<'hir> LoweringContext<'_, 'hir> {
21612168
self.arena.alloc(self.expr_call_lang_item_fn_mut(span, lang_item, args))
21622169
}
21632170

2164-
fn expr_lang_item_path(&mut self, span: Span, lang_item: hir::LangItem) -> hir::Expr<'hir> {
2165-
self.expr(span, hir::ExprKind::Path(hir::QPath::LangItem(lang_item, self.lower_span(span))))
2171+
pub(super) fn expr_lang_item_path(
2172+
&mut self,
2173+
span: Span,
2174+
lang_item: hir::LangItem,
2175+
) -> hir::Expr<'hir> {
2176+
let path = self.lang_item_path(span, lang_item);
2177+
self.expr(span, hir::ExprKind::Path(path))
2178+
}
2179+
2180+
pub(super) fn lang_item_path(
2181+
&mut self,
2182+
span: Span,
2183+
lang_item: hir::LangItem,
2184+
) -> hir::QPath<'hir> {
2185+
hir::QPath::LangItem(lang_item, self.lower_span(span))
21662186
}
21672187

21682188
/// `<LangItem>::name`

0 commit comments

Comments
 (0)