Skip to content

Commit 4137c15

Browse files
committed
flow_control: Support otherwise pattern.
commit-id:cf815f94
1 parent 9bbab68 commit 4137c15

File tree

4 files changed

+53
-41
lines changed

4 files changed

+53
-41
lines changed

crates/cairo-lang-lowering/src/lower/flow_control/create_graph.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use cairo_lang_semantic::{self as semantic, Condition, PatternId};
22
use cairo_lang_syntax::node::TypedStablePtr;
33
use itertools::Itertools;
4-
use patterns::create_node_for_patterns;
4+
use patterns::{create_node_for_patterns, get_pattern};
55

66
use super::graph::{
7-
ArmExpr, BooleanIf, EvaluateExpr, FlowControlGraph, FlowControlGraphBuilder, FlowControlNode, NodeId,
7+
ArmExpr, BooleanIf, EvaluateExpr, FlowControlGraph, FlowControlGraphBuilder, FlowControlNode,
8+
NodeId,
89
};
910
use crate::lower::context::LoweringContext;
1011

@@ -64,10 +65,7 @@ pub fn create_graph_expr_if<'db>(
6465
ctx,
6566
&mut graph,
6667
expr_var,
67-
&patterns
68-
.iter()
69-
.map(|pattern| &ctx.function_body.arenas.patterns[*pattern])
70-
.collect_vec(),
68+
&patterns.iter().map(|pattern| Some(get_pattern(ctx, *pattern))).collect_vec(),
7169
&|_graph, pattern_indices| {
7270
if pattern_indices.first().is_some() { current_node } else { false_branch }
7371
},
@@ -119,7 +117,7 @@ pub fn create_graph_expr_match<'db>(
119117
matched_var,
120118
&pattern_and_nodes
121119
.iter()
122-
.map(|(pattern, _)| &ctx.function_body.arenas.patterns[*pattern])
120+
.map(|(pattern, _)| Some(get_pattern(ctx, *pattern)))
123121
.collect_vec(),
124122
&|_graph, pattern_indices| {
125123
// TODO(lior): add diagnostics if pattern_indices is empty (instead of `unwrap`).

crates/cairo-lang-lowering/src/lower/flow_control/create_graph/filtered_patterns.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,7 @@ impl FilteredPatterns {
4747
/// `foo` needs to lift it to `[2]` to return to its caller using `foo`'s indexing.
4848
pub fn lift(self, outer_filter: &FilteredPatterns) -> Self {
4949
Self {
50-
filter: self
51-
.filter
52-
.into_iter()
53-
.map(|index| outer_filter.filter[index])
54-
.collect_vec(),
50+
filter: self.filter.into_iter().map(|index| outer_filter.filter[index]).collect_vec(),
5551
}
5652
}
5753

crates/cairo-lang-lowering/src/lower/flow_control/create_graph/patterns.rs

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ use crate::lower::context::LoweringContext;
3939
type BuildNodeCallback<'db, 'a> =
4040
&'a dyn Fn(&mut FlowControlGraphBuilder<'db>, FilteredPatterns) -> NodeId;
4141

42+
/// A thin wrapper around [semantic::Pattern], where `None` represents the `_` pattern.
43+
pub type Pattern<'a, 'db> = Option<&'a semantic::Pattern<'db>>;
44+
4245
/// Given a list of patterns and the nodes to go to if the pattern matches,
4346
/// returns a new graph node to handle the patterns.
4447
///
@@ -50,17 +53,14 @@ pub fn create_node_for_patterns<'db>(
5053
ctx: &LoweringContext<'db, '_>,
5154
graph: &mut FlowControlGraphBuilder<'db>,
5255
input_var: FlowControlVar,
53-
patterns: &[&semantic::Pattern<'db>],
56+
patterns: &[Pattern<'_, 'db>],
5457
build_node_callback: BuildNodeCallback<'db, '_>,
5558
location: LocationId<'db>,
5659
) -> NodeId {
5760
// If all the patterns are catch-all, we do not need to look into `input_var`.
5861
if patterns.iter().all(|pattern| pattern_is_any(pattern)) {
5962
// Call the callback with all patterns accepted.
60-
return build_node_callback(
61-
graph,
62-
FilteredPatterns::all(patterns.len())
63-
);
63+
return build_node_callback(graph, FilteredPatterns::all(patterns.len()));
6464
}
6565

6666
let (n_snapshots, long_ty) = peel_snapshots(ctx.db, graph.var_ty(input_var));
@@ -87,7 +87,7 @@ fn create_node_for_enum<'db>(
8787
input_var: FlowControlVar,
8888
concrete_enum_id: ConcreteEnumId<'db>,
8989
n_snapshots: usize,
90-
patterns: &[&semantic::Pattern<'db>],
90+
patterns: &[Pattern<'_, 'db>],
9191
build_node_callback: BuildNodeCallback<'db, '_>,
9292
location: LocationId<'db>,
9393
) -> NodeId {
@@ -100,18 +100,28 @@ fn create_node_for_enum<'db>(
100100
// Maps variant index to the list of the inner patterns.
101101
// For example, a pattern `A(B(x))` will add the (inner) pattern `B(x)` to the vector at the
102102
// index of the variant `A`.
103-
let mut variant_to_inner_patterns: Vec<Vec<&semantic::Pattern>> = vec![vec![]; concrete_variants.len()];
103+
let mut variant_to_inner_patterns: Vec<Vec<Pattern>> = vec![vec![]; concrete_variants.len()];
104104

105105
for (idx, pattern) in patterns.iter().enumerate() {
106106
match pattern {
107-
semantic::Pattern::EnumVariant(PatternEnumVariant { variant, inner_pattern, .. }) => {
107+
Some(semantic::Pattern::EnumVariant(PatternEnumVariant {
108+
variant,
109+
inner_pattern,
110+
..
111+
})) => {
108112
variant_to_pattern_indices[variant.idx].add(idx);
109-
// TODO(lior): Fix the unwrap below.
110-
variant_to_inner_patterns[variant.idx].push(
111-
inner_pattern
112-
.map(|inner_pattern| &ctx.function_body.arenas.patterns[inner_pattern])
113-
.unwrap(),
114-
);
113+
variant_to_inner_patterns[variant.idx]
114+
.push(inner_pattern.map(|inner_pattern| get_pattern(ctx, inner_pattern)));
115+
}
116+
Some(semantic::Pattern::Otherwise(..)) | None => {
117+
// Add `idx` to all the variants.
118+
for pattern_indices in variant_to_pattern_indices.iter_mut() {
119+
pattern_indices.add(idx);
120+
}
121+
// Add the `_` pattern (represented by `None`) to all the variants.
122+
for inner_patterns in variant_to_inner_patterns.iter_mut() {
123+
inner_patterns.push(None);
124+
}
115125
}
116126
_ => todo!("Pattern {:?} is not supported yet.", pattern),
117127
}
@@ -146,16 +156,26 @@ fn create_node_for_enum<'db>(
146156
}
147157

148158
/// Returns `true` if the pattern accepts any value (`_` or a variable name).
149-
fn pattern_is_any(pattern: &semantic::Pattern) -> bool {
159+
fn pattern_is_any(pattern: &Pattern) -> bool {
150160
match pattern {
151-
semantic::Pattern::Otherwise(..)
152-
| semantic::Pattern::Variable(..) => true,
153-
semantic::Pattern::Literal(..) |
154-
semantic::Pattern::StringLiteral(..) |
155-
semantic::Pattern::Struct(..) |
156-
semantic::Pattern::Tuple(..) |
157-
semantic::Pattern::FixedSizeArray(..) |
158-
semantic::Pattern::EnumVariant(..) |
159-
semantic::Pattern::Missing(..) => false,
161+
Some(semantic_pattern) => match semantic_pattern {
162+
semantic::Pattern::Otherwise(..) | semantic::Pattern::Variable(..) => true,
163+
semantic::Pattern::Literal(..)
164+
| semantic::Pattern::StringLiteral(..)
165+
| semantic::Pattern::Struct(..)
166+
| semantic::Pattern::Tuple(..)
167+
| semantic::Pattern::FixedSizeArray(..)
168+
| semantic::Pattern::EnumVariant(..)
169+
| semantic::Pattern::Missing(..) => false,
170+
},
171+
None => true,
160172
}
161173
}
174+
175+
/// Returns a reference to a [semantic::Pattern] from a [semantic::PatternId].
176+
pub fn get_pattern<'db, 'a>(
177+
ctx: &'a LoweringContext<'db, '_>,
178+
semantic_pattern: semantic::PatternId,
179+
) -> &'a semantic::Pattern<'db> {
180+
&ctx.function_body.arenas.patterns[semantic_pattern]
181+
}

crates/cairo-lang-lowering/src/lower/flow_control/test_data/match

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ test_create_graph(expect_diagnostics: false)
66
//! > function_code
77
fn foo(color: Color) -> felt252 {
88
match color {
9-
// TODO(lior): Remove some of the `(_)` once supported.
10-
Color::Red(_) | Color::Green(_) => 1,
11-
Color::Red(_) | Color::Blue(_) => 2,
12-
// TODO(lior): Change to `_` once supported.
13-
Color::Black(_) | Color::White(_) => 3,
9+
Color::Red | Color::Green(_) => 1,
10+
Color::Red(_) | Color::Blue => 2,
11+
_ => 3,
1412
}
1513
}
1614

0 commit comments

Comments
 (0)