Skip to content

Commit cdb6ddb

Browse files
committed
update
1 parent e768a48 commit cdb6ddb

1 file changed

Lines changed: 84 additions & 56 deletions

File tree

src/sql/logical_planner/optimizers/chaining.rs

Lines changed: 84 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -10,93 +10,121 @@
1010
// See the License for the specific language governing permissions and
1111
// limitations under the License.
1212

13-
use std::mem;
14-
15-
use petgraph::prelude::*;
16-
use petgraph::visit::NodeRef;
13+
use petgraph::graph::{EdgeIndex, NodeIndex};
14+
use petgraph::visit::EdgeRef;
15+
use petgraph::Direction::{Incoming, Outgoing};
1716

1817
use crate::sql::logical_node::logical::{LogicalEdgeType, LogicalGraph, Optimizer};
1918

20-
pub struct ChainingOptimizer {}
21-
22-
fn remove_in_place<N, E>(graph: &mut DiGraph<N, E>, node: NodeIndex) {
23-
let incoming = graph.edges_directed(node, Incoming).next().unwrap();
24-
25-
let parent = incoming.source().id();
26-
let incoming = incoming.id();
27-
graph.remove_edge(incoming);
19+
pub type NodeId = NodeIndex;
20+
pub type EdgeId = EdgeIndex;
2821

29-
let outgoing: Vec<_> = graph
30-
.edges_directed(node, Outgoing)
31-
.map(|e| (e.id(), e.target().id()))
32-
.collect();
22+
pub struct ChainingOptimizer;
3323

34-
for (edge, target) in outgoing {
35-
let weight = graph.remove_edge(edge).unwrap();
36-
graph.add_edge(parent, target, weight);
37-
}
24+
impl ChainingOptimizer {
25+
fn find_fusion_candidate(plan: &LogicalGraph) -> Option<(NodeId, NodeId, EdgeId)> {
26+
let node_ids: Vec<NodeId> = plan.node_indices().collect();
3827

39-
graph.remove_node(node);
40-
}
28+
for upstream_id in node_ids {
29+
let upstream_node = plan.node_weight(upstream_id)?;
4130

42-
impl Optimizer for ChainingOptimizer {
43-
fn optimize_once(&self, plan: &mut LogicalGraph) -> bool {
44-
let node_indices: Vec<NodeIndex> = plan.node_indices().collect();
45-
46-
for &node_idx in &node_indices {
47-
let cur = plan.node_weight(node_idx).unwrap();
48-
49-
if cur.operator_chain.is_source() {
31+
if upstream_node.operator_chain.is_source() {
5032
continue;
5133
}
5234

53-
let mut successors = plan.edges_directed(node_idx, Outgoing).collect::<Vec<_>>();
35+
let outgoing_edges: Vec<_> = plan.edges_directed(upstream_id, Outgoing).collect();
5436

55-
if successors.len() != 1 {
37+
if outgoing_edges.len() != 1 {
5638
continue;
5739
}
5840

59-
let edge = successors.remove(0);
60-
let edge_type = edge.weight().edge_type;
41+
let bridging_edge = &outgoing_edges[0];
6142

62-
if edge_type != LogicalEdgeType::Forward {
43+
if bridging_edge.weight().edge_type != LogicalEdgeType::Forward {
6344
continue;
6445
}
6546

66-
let successor_idx = edge.target();
47+
let downstream_id = bridging_edge.target();
48+
let downstream_node = plan.node_weight(downstream_id)?;
6749

68-
let successor_node = plan.node_weight(successor_idx).unwrap();
50+
if downstream_node.operator_chain.is_sink() {
51+
continue;
52+
}
6953

70-
if cur.parallelism != successor_node.parallelism
71-
|| successor_node.operator_chain.is_sink()
72-
{
54+
if upstream_node.parallelism != downstream_node.parallelism {
7355
continue;
7456
}
7557

76-
if plan.edges_directed(successor_idx, Incoming).count() > 1 {
58+
let incoming_edges: Vec<_> = plan.edges_directed(downstream_id, Incoming).collect();
59+
if incoming_edges.len() != 1 {
7760
continue;
7861
}
7962

80-
let mut new_cur = cur.clone();
63+
return Some((upstream_id, downstream_id, bridging_edge.id()));
64+
}
65+
66+
None
67+
}
8168

82-
new_cur.description = format!("{} -> {}", cur.description, successor_node.description);
69+
fn apply_fusion(
70+
plan: &mut LogicalGraph,
71+
upstream_id: NodeId,
72+
downstream_id: NodeId,
73+
bridging_edge_id: EdgeId,
74+
) {
75+
let bridging_edge = plan
76+
.remove_edge(bridging_edge_id)
77+
.expect("Graph Integrity Violation: Bridging edge missing");
78+
79+
let propagated_schema = bridging_edge.schema.clone();
80+
81+
let downstream_outgoing: Vec<_> = plan
82+
.edges_directed(downstream_id, Outgoing)
83+
.map(|e| (e.id(), e.target()))
84+
.collect();
85+
86+
for (edge_id, target_id) in downstream_outgoing {
87+
let edge_weight = plan
88+
.remove_edge(edge_id)
89+
.expect("Graph Integrity Violation: Outgoing edge missing");
90+
91+
plan.add_edge(upstream_id, target_id, edge_weight);
92+
}
8393

84-
new_cur
85-
.operator_chain
86-
.operators
87-
.extend(successor_node.operator_chain.operators.clone());
94+
let downstream_node = plan
95+
.remove_node(downstream_id)
96+
.expect("Graph Integrity Violation: Downstream node missing");
8897

89-
new_cur
90-
.operator_chain
91-
.edges
92-
.push(edge.weight().schema.clone());
98+
let upstream_node = plan
99+
.node_weight_mut(upstream_id)
100+
.expect("Graph Integrity Violation: Upstream node missing");
93101

94-
mem::swap(&mut new_cur, plan.node_weight_mut(node_idx).unwrap());
102+
upstream_node.description = format!(
103+
"{} -> {}",
104+
upstream_node.description, downstream_node.description
105+
);
95106

96-
remove_in_place(plan, successor_idx);
97-
return true;
98-
}
107+
upstream_node
108+
.operator_chain
109+
.operators
110+
.extend(downstream_node.operator_chain.operators);
99111

100-
false
112+
upstream_node
113+
.operator_chain
114+
.edges
115+
.push(propagated_schema);
116+
}
117+
}
118+
119+
impl Optimizer for ChainingOptimizer {
120+
fn optimize_once(&self, plan: &mut LogicalGraph) -> bool {
121+
if let Some((upstream_id, downstream_id, bridging_edge_id)) =
122+
Self::find_fusion_candidate(plan)
123+
{
124+
Self::apply_fusion(plan, upstream_id, downstream_id, bridging_edge_id);
125+
true
126+
} else {
127+
false
128+
}
101129
}
102130
}

0 commit comments

Comments
 (0)