Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/redpiler/src/compile_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl CompileNode {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinkType {
Default,
Side,
Expand Down
2 changes: 2 additions & 0 deletions crates/redpiler/src/passes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod export_graph;
mod identify_nodes;
mod input_search;
mod prune_orphans;
mod redundant_links;
mod unreachable_output;

use mchprs_world::World;
Expand All @@ -30,6 +31,7 @@ pub const fn make_default_pass_manager<'w, W: World>() -> PassManager<'w, W> {
&clamp_weights::ClampWeights,
&dedup_links::DedupLinks,
&constant_fold::ConstantFold,
&redundant_links::PruneRedundantLinks,
&analysis::ss_range_analysis::SSRangeAnalysis,
&unreachable_output::UnreachableOutput,
&constant_coalesce::ConstantCoalesce,
Expand Down
112 changes: 112 additions & 0 deletions crates/redpiler/src/passes/redundant_links.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use std::collections::hash_map::Entry;

use super::Pass;
use crate::compile_graph::{CompileGraph, LinkType, NodeIdx, NodeType};
use crate::passes::AnalysisInfos;
use crate::{CompilerInput, CompilerOptions};
use mchprs_blocks::blocks::ComparatorMode;
use mchprs_world::World;
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use rustc_hash::FxHashMap;
use tracing::trace;

pub struct PruneRedundantLinks;

impl<W: World> Pass<W> for PruneRedundantLinks {
fn run_pass(
&self,
graph: &mut CompileGraph,
_: &CompilerOptions,
_: &CompilerInput<'_, W>,
_: &mut AnalysisInfos,
) {
let mut num_edges_pruned = 0;
let node_indices = graph.node_indices().collect::<Vec<_>>();
for idx in node_indices {
num_edges_pruned += match graph[idx].ty {
NodeType::Comparator {
mode,
far_input: None,
..
} => prune_comparator_inputs(graph, idx, mode),
_ => 0,
};
}
trace!("Removed {num_edges_pruned} edges.");
}

fn status_message(&self) -> &'static str {
"Pruning redundant links"
}
}

/// Whenever a node has links to both the default input and the side input of a comparator,
/// only one of those links actually has an effect on the comparator (dominating link).
/// The other link's effect is always cancelled out by the dominating link.
/// This function determines the dominating link and removes the other.
///
/// The case where a node connects to both the default input and side input of a comparator is most
/// commonly seen in XOR gates implemented by 2 comparators in subtract mode.
fn prune_comparator_inputs(
graph: &mut CompileGraph,
idx: NodeIdx,
comparator_mode: ComparatorMode,
) -> usize {
let mut input_distances: FxHashMap<(NodeIdx, LinkType), u8> = FxHashMap::default();
for edge in graph.edges_directed(idx, Direction::Incoming) {
match input_distances.entry((edge.source(), edge.weight().ty)) {
Entry::Occupied(occupied) => {
let cur_distance = occupied.into_mut();
*cur_distance = std::cmp::min(*cur_distance, edge.weight().ss);
}
Entry::Vacant(vacant) => {
vacant.insert(edge.weight().ss);
}
};
}

let mut edges_to_be_removed = Vec::new();
for edge in graph.edges_directed(idx, Direction::Incoming) {
let default_distance = *input_distances
.get(&(edge.source(), LinkType::Default))
.unwrap_or(&u8::MAX);
let side_distance = *input_distances
.get(&(edge.source(), LinkType::Side))
.unwrap_or(&u8::MAX);
let dominated_input =
dominated_comparator_input(comparator_mode, default_distance, side_distance);
if Some(edge.weight().ty) == dominated_input {
edges_to_be_removed.push(edge.id());
}
}

let num_edges_pruned = edges_to_be_removed.len();
for edge in edges_to_be_removed {
graph.remove_edge(edge);
}
num_edges_pruned
}

fn dominated_comparator_input(
comparator_mode: ComparatorMode,
default_distance: u8,
side_distance: u8,
) -> Option<LinkType> {
match comparator_mode {
ComparatorMode::Compare => {
if default_distance <= side_distance {
Some(LinkType::Side)
} else {
Some(LinkType::Default)
}
}
ComparatorMode::Subtract => {
if side_distance <= default_distance {
Some(LinkType::Default)
} else {
None
}
}
}
}