blob: da96611ec36669ba054ae187e08edd691b81cb15 [file] [log] [blame] [edit]
//! Simple union-find data structure.
use crate::trace;
use cranelift_entity::{packed_option::ReservedValue, EntityRef, SecondaryMap};
use std::hash::Hash;
use std::mem::swap;
/// A union-find data structure. The data structure can allocate
/// `Idx`s, indicating eclasses, and can merge eclasses together.
///
/// Running `union(a, b)` will change the canonical `Idx` of `a` or `b`.
/// Usually, this is chosen based on what will minimize path lengths,
/// but it is also possible to _pin_ an eclass, such that its canonical `Idx`
/// won't change unless it gets unioned with another pinned eclass.
///
/// In the context of the egraph pass, merging two pinned eclasses
/// is very unlikely to happen – we do not know a single concrete test case
/// where it does. The only situation where it might happen looks as follows:
///
/// 1. We encounter terms `A` and `B`, and the optimizer does not find any
/// reason to union them together.
/// 2. We encounter a term `C`, and we rewrite `C -> A`, and separately, `C -> B`.
///
/// Unless `C` somehow includes some crucial hint without which it is hard to
/// notice that `A = B`, there's probably a rewrite rule that we should add.
///
/// Worst case, if we do merge two pinned eclasses, some nodes will essentially
/// disappear from the GVN map, which only affects the quality of the generated
/// code.
#[derive(Clone, Debug, PartialEq)]
pub struct UnionFind<Idx: EntityRef> {
/// Parent links of the union-find forest: `parent[x]` holds the parent of
/// `x`, and a canonical (root) `Idx` is recorded as its own parent.
/// Slots that were never `add`ed are expected to hold `Val::default()`,
/// i.e. `Idx::reserved_value()`.
parent: SecondaryMap<Idx, Val<Idx>>,
/// The `rank` table is used to perform the union operations optimally,
/// without creating unnecessarily long paths. Pins are represented by
/// eclasses with a rank of `u8::MAX`.
///
/// `rank[x]` is the upper bound on the height of the subtree rooted at `x`.
/// The subtree is guaranteed to have at least `2**rank[x]` elements,
/// unless `rank` has been artificially inflated by pinning.
rank: SecondaryMap<Idx, u8>,
/// Number of `union` calls that merged two pinned eclasses (both roots
/// already at rank `u8::MAX`), breaking one of the pins.
pub(crate) pinned_union_count: u64,
}
/// Newtype around an `Idx`, used as the element type of the `parent` table.
/// Its `Default` impl yields `Idx::reserved_value()`.
#[derive(Clone, Debug, PartialEq)]
struct Val<Idx>(Idx);
impl<Idx: EntityRef + ReservedValue> Default for Val<Idx> {
fn default() -> Self {
Self(Idx::reserved_value())
}
}
impl<Idx: EntityRef + Hash + std::fmt::Display + Ord + ReservedValue> UnionFind<Idx> {
/// Create a new `UnionFind` with the given capacity.
pub fn with_capacity(cap: usize) -> Self {
UnionFind {
parent: SecondaryMap::with_capacity(cap),
rank: SecondaryMap::with_capacity(cap),
pinned_union_count: 0,
}
}
/// Add an `Idx` to the `UnionFind`, with its own equivalence class
/// initially. All `Idx`s must be added before being queried or
/// unioned.
pub fn add(&mut self, id: Idx) {
debug_assert!(id != Idx::reserved_value());
self.parent[id] = Val(id);
}
/// Find the canonical `Idx` of a given `Idx`.
pub fn find(&self, mut node: Idx) -> Idx {
while node != self.parent[node].0 {
node = self.parent[node].0;
}
node
}
/// Find the canonical `Idx` of a given `Idx`, updating the data
/// structure in the process so that future queries for this `Idx`
/// (and others in its chain up to the root of the equivalence
/// class) will be faster.
pub fn find_and_update(&mut self, mut node: Idx) -> Idx {
// "Path halving" mutating find (Tarjan and Van Leeuwen).
debug_assert!(node != Idx::reserved_value());
while node != self.parent[node].0 {
let next = self.parent[self.parent[node].0].0;
debug_assert!(next != Idx::reserved_value());
self.parent[node] = Val(next);
node = next;
}
debug_assert!(node != Idx::reserved_value());
node
}
/// Request a stable identifier for `node`.
///
/// After an `union` operation, the canonical representative of one
/// of the eclasses being merged together necessarily changes. If a pinned
/// eclass is merged with a non-pinned eclass, it'll be the other eclass
/// whose representative will change.
///
/// If two pinned eclasses are unioned, one of the pins gets broken,
/// which is reported in the statistics for the pass. No concrete test case
/// which triggers this is known.
pub fn pin_index(&mut self, mut node: Idx) -> Idx {
node = self.find_and_update(node);
self.rank[node] = u8::MAX;
node
}
/// Merge the equivalence classes of the two `Idx`s.
pub fn union(&mut self, a: Idx, b: Idx) {
let mut a = self.find_and_update(a);
let mut b = self.find_and_update(b);
if a == b {
return;
}
if self.rank[a] < self.rank[b] {
swap(&mut a, &mut b);
} else if self.rank[a] == self.rank[b] {
self.rank[a] = self.rank[a].checked_add(1).unwrap_or_else(
#[cold]
|| {
// Both `a` and `b` are pinned.
//
// This should only occur if we rewrite X -> Y and X -> Z,
// yet neither Y -> Z nor Z -> Y can be established without
// the "hint" provided by X. This probably means we're
// missing an optimization rule.
self.pinned_union_count += 1;
u8::MAX
},
);
}
self.parent[b] = Val(a);
trace!("union: {}, {}", a, b);
}
}