diff --git a/libs/@local/hashql/mir/src/body/operand.rs b/libs/@local/hashql/mir/src/body/operand.rs index 8652917601d..b64ff6ac739 100644 --- a/libs/@local/hashql/mir/src/body/operand.rs +++ b/libs/@local/hashql/mir/src/body/operand.rs @@ -3,7 +3,7 @@ //! Operands represent the inputs to MIR operations. They can either reference //! a storage location (place) or contain an immediate constant value. -use super::{constant::Constant, place::Place}; +use super::{constant::Constant, local::Local, place::Place}; /// An operand in a HashQL MIR operation. /// @@ -59,6 +59,12 @@ impl<'heap> Operand<'heap> { } } +impl From for Operand<'_> { + fn from(local: Local) -> Self { + Operand::Place(Place::local(local)) + } +} + impl<'heap> From> for Operand<'heap> { fn from(place: Place<'heap>) -> Self { Operand::Place(place) diff --git a/libs/@local/hashql/mir/src/body/place.rs b/libs/@local/hashql/mir/src/body/place.rs index 5a50a35480e..d39fed2eb2f 100644 --- a/libs/@local/hashql/mir/src/body/place.rs +++ b/libs/@local/hashql/mir/src/body/place.rs @@ -315,10 +315,11 @@ impl<'heap> Place<'heap> { /// This is the simplest form of a place, representing direct access to a local variable /// without navigating through any structured data. The resulting place has an empty /// projection sequence. - pub fn local(local: Local, interner: &Interner<'heap>) -> Self { + #[must_use] + pub const fn local(local: Local) -> Self { Self { local, - projections: interner.projections.intern_slice(&[]), + projections: Interned::empty(), } } diff --git a/libs/@local/hashql/mir/src/body/terminator/switch_int.rs b/libs/@local/hashql/mir/src/body/terminator/switch_int.rs index d7ab7ffc9aa..40ba6c8eeab 100644 --- a/libs/@local/hashql/mir/src/body/terminator/switch_int.rs +++ b/libs/@local/hashql/mir/src/body/terminator/switch_int.rs @@ -74,20 +74,17 @@ pub enum SwitchIntValue { /// let targets = SwitchTargets::new( /// &heap, /// [ -/// (0, Target::block(bb0, &interner)), -/// (1, Target::block(bb1, &interner)), -/// (2, Target::block(bb2, &interner)), +/// (0, Target::block(bb0)), +/// (1, Target::block(bb1)), +/// (2, Target::block(bb2)), /// ], -/// Some(Target::block(otherwise, &interner)), +/// Some(Target::block(otherwise)), /// ); /// /// // Values are automatically sorted /// assert_eq!(targets.values(), &[0, 1, 2]); -/// assert_eq!(targets.target(1), Some(Target::block(bb1, &interner))); -/// assert_eq!( -/// targets.target(99), -/// Some(Target::block(otherwise, &interner)) -/// ); +/// assert_eq!(targets.target(1), Some(Target::block(bb1))); +/// assert_eq!(targets.target(99), Some(Target::block(otherwise))); /// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SwitchTargets<'heap> { @@ -131,11 +128,11 @@ impl<'heap> SwitchTargets<'heap> { /// let targets = SwitchTargets::new( /// &heap, /// [ - /// (10, Target::block(BasicBlockId::new(0), &interner)), - /// (20, Target::block(BasicBlockId::new(1), &interner)), - /// (30, Target::block(BasicBlockId::new(2), &interner)), + /// (10, Target::block(BasicBlockId::new(0))), + /// (20, Target::block(BasicBlockId::new(1))), + /// (30, Target::block(BasicBlockId::new(2))), /// ], - /// Some(Target::block(BasicBlockId::new(3), &interner)), + /// Some(Target::block(BasicBlockId::new(3))), /// ); /// /// assert_eq!(targets.values(), &[10, 20, 30]); @@ -158,11 +155,7 @@ impl<'heap> SwitchTargets<'heap> { /// let heap = Heap::new(); /// let interner = Interner::new(&heap); /// - /// let targets = SwitchTargets::new( - /// &heap, - /// [(0, Target::block(BasicBlockId::new(0), &interner))], - /// None, - /// ); + /// let targets = SwitchTargets::new(&heap, [(0, Target::block(BasicBlockId::new(0)))], None); /// /// assert_eq!(targets.otherwise(), None); /// assert_eq!(targets.target(99), None); // No otherwise, so None for unmatched values @@ -214,8 +207,8 @@ impl<'heap> SwitchTargets<'heap> { /// let heap = Heap::new(); /// let interner = Interner::new(&heap); /// - /// let then_block = Target::block(BasicBlockId::new(1), &interner); - /// let else_block = Target::block(BasicBlockId::new(2), &interner); + /// let then_block = Target::block(BasicBlockId::new(1)); + /// let else_block = Target::block(BasicBlockId::new(2)); /// /// let targets = SwitchTargets::new_if(&heap, then_block, else_block); /// @@ -261,8 +254,8 @@ impl<'heap> SwitchTargets<'heap> { /// let heap = Heap::new(); /// let interner = Interner::new(&heap); /// - /// let then_block = Target::block(BasicBlockId::new(1), &interner); - /// let else_block = Target::block(BasicBlockId::new(2), &interner); + /// let then_block = Target::block(BasicBlockId::new(1)); + /// let else_block = Target::block(BasicBlockId::new(2)); /// /// // Binary switch can be converted /// let binary = SwitchTargets::new_if(&heap, then_block, else_block); @@ -316,11 +309,11 @@ impl<'heap> SwitchTargets<'heap> { /// let heap = Heap::new(); /// let interner = Interner::new(&heap); /// - /// let default = Target::block(BasicBlockId::new(99), &interner); + /// let default = Target::block(BasicBlockId::new(99)); /// /// let with_otherwise = SwitchTargets::new( /// &heap, - /// [(1, Target::block(BasicBlockId::new(0), &interner))], + /// [(1, Target::block(BasicBlockId::new(0)))], /// Some(default), /// ); /// assert_eq!(with_otherwise.otherwise(), Some(default)); @@ -353,9 +346,9 @@ impl<'heap> SwitchTargets<'heap> { /// let heap = Heap::new(); /// let interner = Interner::new(&heap); /// - /// let bb0 = Target::block(BasicBlockId::new(0), &interner); - /// let bb1 = Target::block(BasicBlockId::new(1), &interner); - /// let otherwise = Target::block(BasicBlockId::new(99), &interner); + /// let bb0 = Target::block(BasicBlockId::new(0)); + /// let bb1 = Target::block(BasicBlockId::new(1)); + /// let otherwise = Target::block(BasicBlockId::new(99)); /// /// let targets = SwitchTargets::new(&heap, [(10, bb0), (20, bb1)], Some(otherwise)); /// @@ -417,8 +410,8 @@ impl<'heap> SwitchTargets<'heap> { /// let mut targets = SwitchTargets::new( /// &heap, /// [ - /// (1, Target::block(BasicBlockId::new(0), &interner)), - /// (2, Target::block(BasicBlockId::new(1), &interner)), + /// (1, Target::block(BasicBlockId::new(0))), + /// (2, Target::block(BasicBlockId::new(1))), /// ], /// None, /// ); @@ -459,10 +452,10 @@ impl<'heap> SwitchTargets<'heap> { /// let targets = SwitchTargets::new( /// &heap, /// [ - /// (10, Target::block(BasicBlockId::new(0), &interner)), - /// (20, Target::block(BasicBlockId::new(1), &interner)), + /// (10, Target::block(BasicBlockId::new(0))), + /// (20, Target::block(BasicBlockId::new(1))), /// ], - /// Some(Target::block(BasicBlockId::new(99), &interner)), + /// Some(Target::block(BasicBlockId::new(99))), /// ); /// /// let pairs: Vec<_> = targets.iter().collect(); @@ -501,8 +494,8 @@ impl<'heap> SwitchTargets<'heap> { /// let mut targets = SwitchTargets::new( /// &heap, /// [ - /// (10, Target::block(BasicBlockId::new(0), &interner)), - /// (20, Target::block(BasicBlockId::new(1), &interner)), + /// (10, Target::block(BasicBlockId::new(0))), + /// (20, Target::block(BasicBlockId::new(1))), /// ], /// None, /// ); @@ -545,8 +538,8 @@ impl<'heap> SwitchTargets<'heap> { /// /// let mut targets = SwitchTargets::new(&heap, [], None); /// - /// targets.add_target(10, Target::block(BasicBlockId::new(0), &interner)); - /// targets.add_target(5, Target::block(BasicBlockId::new(1), &interner)); + /// targets.add_target(10, Target::block(BasicBlockId::new(0))); + /// targets.add_target(5, Target::block(BasicBlockId::new(1))); /// /// // Values are kept sorted /// assert_eq!(targets.values(), &[5, 10]); @@ -587,16 +580,12 @@ impl<'heap> SwitchTargets<'heap> { /// let heap = Heap::new(); /// let interner = Interner::new(&heap); /// - /// let mut first = SwitchTargets::new( - /// &heap, - /// [(10, Target::block(BasicBlockId::new(0), &interner))], - /// None, - /// ); + /// let mut first = SwitchTargets::new(&heap, [(10, Target::block(BasicBlockId::new(0)))], None); /// /// let mut second = SwitchTargets::new( /// &heap, - /// [(20, Target::block(BasicBlockId::new(1), &interner))], - /// Some(Target::block(BasicBlockId::new(99), &interner)), + /// [(20, Target::block(BasicBlockId::new(1)))], + /// Some(Target::block(BasicBlockId::new(99))), /// ); /// /// first.append(&mut second); @@ -702,11 +691,11 @@ impl<'heap> SwitchTargets<'heap> { /// let targets = SwitchTargets::new( /// &heap, /// [ -/// (0, Target::block(BasicBlockId::new(0), &interner)), -/// (1, Target::block(BasicBlockId::new(1), &interner)), -/// (2, Target::block(BasicBlockId::new(2), &interner)), +/// (0, Target::block(BasicBlockId::new(0))), +/// (1, Target::block(BasicBlockId::new(1))), +/// (2, Target::block(BasicBlockId::new(2))), /// ], -/// Some(Target::block(BasicBlockId::new(3), &interner)), // otherwise +/// Some(Target::block(BasicBlockId::new(3))), // otherwise /// ); /// /// // Create the switch with an integer discriminant @@ -734,12 +723,12 @@ impl<'heap> SwitchTargets<'heap> { /// let heap = Heap::new(); /// let interner = Interner::new(&heap); /// -/// let then_target = Target::block(BasicBlockId::new(1), &interner); -/// let else_target = Target::block(BasicBlockId::new(2), &interner); +/// let then_target = Target::block(BasicBlockId::new(1)); +/// let else_target = Target::block(BasicBlockId::new(2)); /// /// // Create a binary switch for if-else /// let switch = SwitchInt { -/// discriminant: Operand::Place(Place::local(Local::new(0), &interner)), +/// discriminant: Operand::Place(Place::local(Local::new(0))), /// targets: SwitchTargets::new_if(&heap, then_target, else_target), /// }; /// diff --git a/libs/@local/hashql/mir/src/body/terminator/target.rs b/libs/@local/hashql/mir/src/body/terminator/target.rs index 2940d9d5061..9a92fd6e562 100644 --- a/libs/@local/hashql/mir/src/body/terminator/target.rs +++ b/libs/@local/hashql/mir/src/body/terminator/target.rs @@ -5,10 +5,7 @@ use hashql_core::intern::Interned; -use crate::{ - body::{basic_block::BasicBlockId, operand::Operand}, - intern::Interner, -}; +use crate::body::{basic_block::BasicBlockId, operand::Operand}; /// A control flow target in the HashQL MIR. /// @@ -39,11 +36,12 @@ pub struct Target<'heap> { pub args: Interned<'heap, [Operand<'heap>]>, } -impl<'heap> Target<'heap> { - pub fn block(block: impl Into, interner: &Interner<'heap>) -> Self { +impl Target<'_> { + #[must_use] + pub const fn block(block: BasicBlockId) -> Self { Self { - block: block.into(), - args: interner.operands.intern_slice(&[]), + block, + args: Interned::empty(), } } } diff --git a/libs/@local/hashql/mir/src/builder/body.rs b/libs/@local/hashql/mir/src/builder/body.rs index 1240857b395..2f1337352a8 100644 --- a/libs/@local/hashql/mir/src/builder/body.rs +++ b/libs/@local/hashql/mir/src/builder/body.rs @@ -70,7 +70,7 @@ impl<'env, 'heap> BodyBuilder<'env, 'heap> { }; let local = self.local_decls.push(decl); - Place::local(local, self.interner) + Place::local(local) } /// Reserves a new basic block and returns its ID. diff --git a/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs b/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs index 3b09f4ee85b..3886a73c94f 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs @@ -299,7 +299,7 @@ fn resolve_params_const<'heap, A: Allocator + Clone>( } else { // We have finished (we have terminated on a param, which is divergent, therefore the place // is still valid, just doesn't have a constant value) - ResolutionResult::Resolved(Operand::Place(Place::local(place.local, state.interner))) + ResolutionResult::Resolved(Operand::Place(Place::local(place.local))) } } @@ -382,15 +382,14 @@ pub(crate) fn resolve<'heap, A: Allocator + Clone>( [] => { // Base case: no more projections to resolve. // Check for constant propagation through Load. - let operand = if let Some(constant) = state + let operand = state .graph .constant_bindings .find_by_kind(place.local, EdgeKind::Load) - { - Operand::Constant(constant) - } else { - Operand::Place(Place::local(place.local, state.interner)) - }; + .map_or_else( + || Operand::Place(Place::local(place.local)), + Operand::Constant, + ); return ResolutionResult::Resolved(operand); } diff --git a/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/visitor.rs b/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/visitor.rs index 1707d1aef1a..5a4c37c8a17 100644 --- a/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/visitor.rs +++ b/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/visitor.rs @@ -221,7 +221,7 @@ impl<'heap, A: Allocator> AdministrativeReductionVisitor<'_, '_, 'heap, A> { .enumerate() .map(|(param, argument)| Statement { kind: StatementKind::Assign(Assign { - lhs: Place::local(Local::new(local_offset + param), self.interner), + lhs: Place::local(Local::new(local_offset + param)), rhs: RValue::Load(argument), }), span, diff --git a/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs b/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs index 8a69a40d9b2..3e545a6d420 100644 --- a/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/mod.rs @@ -122,12 +122,7 @@ impl CfgSimplify { /// 3. Replace `A`'s terminator with `B`'s terminator /// /// SSA invariants may be temporarily broken; the [`SsaRepair`] runs afterward to fix them. - fn simplify_goto<'heap>( - context: &MirContext<'_, 'heap>, - body: &mut Body<'heap>, - id: BasicBlockId, - goto: Goto<'heap>, - ) -> bool { + fn simplify_goto<'heap>(body: &mut Body<'heap>, id: BasicBlockId, goto: Goto<'heap>) -> bool { // Self-loops cannot be optimized as there's no simplification possible. if goto.target.block == id { return false; @@ -170,7 +165,7 @@ impl CfgSimplify { block.statements.push(Statement { span: block.terminator.span, kind: StatementKind::Assign(Assign { - lhs: Place::local(param, context.interner), + lhs: Place::local(param), rhs: RValue::Load(arg), }), }); @@ -424,7 +419,7 @@ impl CfgSimplify { .transfer_into(&self.alloc); let changed = match &body.basic_blocks[id].terminator.kind { - &TerminatorKind::Goto(goto) => Self::simplify_goto(context, body, id, goto), + &TerminatorKind::Goto(goto) => Self::simplify_goto(body, id, goto), TerminatorKind::SwitchInt(_) => Self::simplify_switch_int(context, body, id), TerminatorKind::Return(_) | TerminatorKind::GraphRead(_) diff --git a/libs/@local/hashql/mir/src/pass/transform/copy_propagation/mod.rs b/libs/@local/hashql/mir/src/pass/transform/copy_propagation/mod.rs index 8d7cc2a3982..3570242e3e8 100644 --- a/libs/@local/hashql/mir/src/pass/transform/copy_propagation/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/copy_propagation/mod.rs @@ -1,27 +1,45 @@ //! Copy and constant propagation transformation pass. //! -//! This pass propagates constant values through the MIR by tracking which locals hold known -//! constants and substituting uses of those locals with the constant values directly. +//! This pass propagates both constants and copies through the MIR by tracking which locals hold +//! known values (either constants or references to other locals) and substituting uses accordingly. //! //! Unlike [`ForwardSubstitution`], this pass does not perform full data dependency analysis and //! cannot resolve values through projections or chained access paths. It is faster but less -//! comprehensive, making it suitable for quick constant folding in simpler cases. +//! comprehensive, making it suitable for quick propagation in simpler cases. //! //! # Algorithm //! //! The pass operates in a single forward traversal (reverse postorder): //! -//! 1. For each block, propagates constants through block parameters when all predecessors pass the -//! same constant value -//! 2. For each assignment `_x = `, if the operand is a constant or a local known to hold a -//! constant, records that `_x` holds that constant -//! 3. For each use of a local known to hold a constant, substitutes the use with the constant +//! 1. For each block, propagates values through block parameters when all predecessors pass the +//! same value (constant or local) +//! 2. For each assignment `_x = `, records what `_x` holds: +//! - If the operand is a constant, records `_x → constant` +//! - If the operand is a local (possibly with a known value), records `_x → known value` +//! 3. For each use of a local with a known value, substitutes the use with that value +//! +//! # Examples +//! +//! Constant propagation: +//! ```text +//! _1 = const 42; use(_1) → _1 = const 42; use(const 42) +//! ``` +//! +//! Copy propagation: +//! ```text +//! _2 = _1; use(_2) → _2 = _1; use(_1) +//! ``` +//! +//! Chained propagation: +//! ```text +//! _2 = _1; _3 = _2; use(_3) → _2 = _1; _3 = _1; use(_1) +//! ``` //! //! # Limitations //! //! - Does not handle projections: `_2 = (_1,); use(_2.0)` is not simplified //! - Does not perform fix-point iteration for loops -//! - Only tracks constants, not arbitrary value equivalences +//! - Assumes SSA-like semantics (locals are assigned at most once) //! //! For more comprehensive value propagation including projections, see [`ForwardSubstitution`]. //! @@ -46,6 +64,7 @@ use crate::{ local::{Local, LocalVec}, location::Location, operand::Operand, + place::Place, rvalue::RValue, statement::Assign, }, @@ -165,6 +184,21 @@ where .flatten() } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +enum KnownValue<'heap> { + Constant(Constant<'heap>), + Local(Local), +} + +impl<'heap> From> for Operand<'heap> { + fn from(value: KnownValue<'heap>) -> Self { + match value { + KnownValue::Constant(constant) => Operand::Constant(constant), + KnownValue::Local(local) => Operand::Place(Place::local(local)), + } + } +} + pub struct CopyPropagation { alloc: A, } @@ -196,7 +230,7 @@ impl<'env, 'heap, A: ResetAllocator> TransformPass<'env, 'heap> for CopyPropagat let mut visitor = CopyPropagationVisitor { interner: context.interner, - constants: IdVec::with_capacity_in(body.local_decls.len(), &self.alloc), + values: IdVec::with_capacity_in(body.local_decls.len(), &self.alloc), changed: false, }; @@ -208,10 +242,10 @@ impl<'env, 'heap, A: ResetAllocator> TransformPass<'env, 'heap> for CopyPropagat let mut args = Vec::new_in(&self.alloc); for &mut id in reverse_postorder { - for (local, constant) in + for (local, value) in propagate_block_params(&mut args, body, id, |operand| visitor.try_eval(operand)) { - visitor.constants.insert(local, constant); + visitor.values.insert(local, value); } Ok(()) = @@ -224,28 +258,27 @@ impl<'env, 'heap, A: ResetAllocator> TransformPass<'env, 'heap> for CopyPropagat struct CopyPropagationVisitor<'env, 'heap, A: Allocator> { interner: &'env Interner<'heap>, - constants: LocalVec>, A>, + values: LocalVec>, A>, changed: bool, } impl<'heap, A: Allocator> CopyPropagationVisitor<'_, 'heap, A> { /// Attempts to evaluate an operand to a known constant or classify it for simplification. - /// - /// Returns `Int` if the operand is a constant integer or a local known to hold one, - /// `Place` if it's a non-constant place, or `Other` for operands that can't be simplified. - fn try_eval(&self, operand: Operand<'heap>) -> Option> { - if let Operand::Constant(constant) = operand { - return Some(constant); + fn try_eval(&self, operand: Operand<'heap>) -> Option> { + let place = match operand { + Operand::Place(place) => place, + Operand::Constant(constant) => return Some(KnownValue::Constant(constant)), + }; + + if !place.projections.is_empty() { + return None; } - if let Operand::Place(place) = operand - && place.projections.is_empty() - && let Some(&constant) = self.constants.lookup(place.local) - { - return Some(constant); + if let Some(&known) = self.values.lookup(place.local) { + return Some(known); } - None + Some(KnownValue::Local(place.local)) } } @@ -262,12 +295,10 @@ impl<'heap, A: Allocator> VisitorMut<'heap> for CopyPropagationVisitor<'_, 'heap } fn visit_operand(&mut self, _: Location, operand: &mut Operand<'heap>) -> Self::Result<()> { - if let Operand::Place(place) = operand - && place.projections.is_empty() - && let Some(&constant) = self.constants.lookup(place.local) - { - *operand = Operand::Constant(constant); - self.changed = true; + if let Some(known) = self.try_eval(*operand) { + let known: Operand<'heap> = known.into(); + self.changed |= known != *operand; + *operand = known; } Ok(()) @@ -292,16 +323,8 @@ impl<'heap, A: Allocator> VisitorMut<'heap> for CopyPropagationVisitor<'_, 'heap return Ok(()); }; - match load { - Operand::Place(place) if place.projections.is_empty() => { - if let Some(&constant) = self.constants.lookup(place.local) { - self.constants.insert(lhs.local, constant); - } - } - Operand::Place(_) => {} - &mut Operand::Constant(constant) => { - self.constants.insert(lhs.local, constant); - } + if let Some(known) = self.try_eval(*load) { + self.values.insert(lhs.local, known); } Ok(()) diff --git a/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs b/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs index 73672a127e4..862349912a8 100644 --- a/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs @@ -377,3 +377,138 @@ fn loop_back_edge() { }, ); } + +/// Tests simple copy propagation: `_2 = _1; use(_2)` → `use(_1)`. +#[test] +fn simple_copy() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Bool { + decl x: Int, y: Int, r: Bool; + + bb0() { + y = load x; + r = bin.== y y; + return r; + } + }); + + assert_cp_pass( + "simple_copy", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests chained copy propagation: `_2 = _1; _3 = _2; use(_3)` → `use(_1)`. +#[test] +fn copy_chain() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Bool { + decl x: Int, y: Int, z: Int, r: Bool; + + bb0() { + y = load x; + z = load y; + r = bin.== z z; + return r; + } + }); + + assert_cp_pass( + "copy_chain", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests block parameter propagation when all predecessors pass the same local (copy). +#[test] +fn block_param_copy() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Bool { + decl x: Int, cond: Bool, p: Int, r: Bool; + + bb0() { + cond = load true; + if cond then bb1() else bb2(); + }, + bb1() { + goto bb3(x); + }, + bb2() { + goto bb3(x); + }, + bb3(p) { + r = bin.== p p; + return r; + } + }); + + assert_cp_pass( + "block_param_copy", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that block parameters are not propagated when predecessors pass different locals. +#[test] +fn block_param_copy_disagreement() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Bool { + decl x: Int, y: Int, cond: Bool, p: Int, r: Bool; + + bb0() { + cond = load true; + if cond then bb1() else bb2(); + }, + bb1() { + goto bb3(x); + }, + bb2() { + goto bb3(y); + }, + bb3(p) { + r = bin.== p p; + return r; + } + }); + + assert_cp_pass( + "block_param_copy_disagreement", + body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} diff --git a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs index 6297b55811b..fbf17cd46f1 100644 --- a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/mod.rs @@ -593,7 +593,7 @@ impl<'heap> VisitorMut<'heap> for RewireBody<'_, 'heap> { // Sanity check to ensure that our previous analysis step isn't divergent debug_assert_eq!(Some(current_local), self.last_def); - let operand = Operand::Place(Place::local(current_local, self.interner)); + let operand = Operand::Place(Place::local(current_local)); let mut args = TinyVec::from_slice_copy(&target.args); args.push(operand); diff --git a/libs/@local/hashql/mir/src/reify/current.rs b/libs/@local/hashql/mir/src/reify/current.rs index 9d3440ae5e8..f033cc15721 100644 --- a/libs/@local/hashql/mir/src/reify/current.rs +++ b/libs/@local/hashql/mir/src/reify/current.rs @@ -2,6 +2,7 @@ use core::mem; use hashql_core::{ heap::{self, Heap}, + intern::Interned, span::SpanId, }; @@ -75,11 +76,11 @@ pub(crate) struct CurrentBlock<'mir, 'heap> { } impl<'mir, 'heap> CurrentBlock<'mir, 'heap> { - pub(crate) fn new(heap: &'heap Heap, interner: &'mir Interner<'heap>) -> Self { + pub(crate) const fn new(heap: &'heap Heap, interner: &'mir Interner<'heap>) -> Self { Self { heap, interner, - block: Self::empty_block(heap, interner), + block: Self::empty_block(heap), slot: None, entry: None, forward_ref: Vec::new(), @@ -95,9 +96,9 @@ impl<'mir, 'heap> CurrentBlock<'mir, 'heap> { self.block.statements.push(statement); } - fn empty_block(heap: &'heap Heap, interner: &Interner<'heap>) -> BasicBlock<'heap> { + const fn empty_block(heap: &'heap Heap) -> BasicBlock<'heap> { BasicBlock { - params: interner.locals.intern_slice(&[]), + params: Interned::empty(), statements: heap::Vec::new_in(heap), // This terminator is temporary and is going to get replaced once finished terminator: Terminator { @@ -165,7 +166,7 @@ impl<'mir, 'heap> CurrentBlock<'mir, 'heap> { } pub(crate) fn reserve(&mut self, blocks: &mut BasicBlockVec, &'heap Heap>) { - self.slot = Some(blocks.push(Self::empty_block(self.heap, self.interner))); + self.slot = Some(blocks.push(Self::empty_block(self.heap))); } pub(crate) fn terminate( @@ -175,7 +176,7 @@ impl<'mir, 'heap> CurrentBlock<'mir, 'heap> { blocks: &mut BasicBlockVec, &'heap Heap>, ) -> ExitBlock { // Finishes the current block, and starts a new one - let previous = mem::replace(&mut self.block, Self::empty_block(self.heap, self.interner)); + let previous = mem::replace(&mut self.block, Self::empty_block(self.heap)); let (_, id) = Self::complete( previous, terminator, diff --git a/libs/@local/hashql/mir/src/reify/mod.rs b/libs/@local/hashql/mir/src/reify/mod.rs index 3e5532869cc..2561add3950 100644 --- a/libs/@local/hashql/mir/src/reify/mod.rs +++ b/libs/@local/hashql/mir/src/reify/mod.rs @@ -246,7 +246,7 @@ impl<'ctx, 'mir, 'hir, 'env, 'heap> Reifier<'ctx, 'mir, 'hir, 'env, 'heap> { block.push_statement(Statement { span, kind: StatementKind::Assign(Assign { - lhs: Place::local(local, self.context.mir.interner), + lhs: Place::local(local), rhs: RValue::Load(Operand::Place(Place { local: env, projections: self.context.mir.interner.projections.intern_slice(&[ @@ -372,13 +372,12 @@ impl<'ctx, 'mir, 'hir, 'env, 'heap> Reifier<'ctx, 'mir, 'hir, 'env, 'heap> { r#type: closure_type.returns, name: None, }); - let lhs = Place::local(output, this.context.mir.interner); + let lhs = Place::local(output); let operand = if let Some(param) = param { Operand::Place(Place::local( this.locals[param.id] .unwrap_or_else(|| unreachable!("We just verified this local exists")), - this.context.mir.interner, )) } else { Operand::Constant(Constant::Unit) diff --git a/libs/@local/hashql/mir/src/reify/rvalue.rs b/libs/@local/hashql/mir/src/reify/rvalue.rs index 71e4818c38f..7c88353c142 100644 --- a/libs/@local/hashql/mir/src/reify/rvalue.rs +++ b/libs/@local/hashql/mir/src/reify/rvalue.rs @@ -215,7 +215,7 @@ impl<'mir, 'heap> Reifier<'_, 'mir, '_, '_, 'heap> { .push(fat_call_on_constant(function.span)); // Return a bogus value / place that can be used to continue lowering - Place::local(Local::MAX, self.context.mir.interner) + Place::local(Local::MAX) } }; @@ -293,7 +293,7 @@ impl<'mir, 'heap> Reifier<'_, 'mir, '_, '_, 'heap> { // that are referenced out of scope (upvars). let mut closure_operands = IdVec::with_capacity_in(2, self.context.mir.heap); closure_operands.push(Operand::Constant(Constant::FnPtr(ptr))); - closure_operands.push(Operand::Place(Place::local(env, self.context.mir.interner))); + closure_operands.push(Operand::Place(Place::local(env))); RValue::Aggregate(Aggregate { kind: AggregateKind::Closure, diff --git a/libs/@local/hashql/mir/src/reify/terminator.rs b/libs/@local/hashql/mir/src/reify/terminator.rs index 4d3639681f0..6e0a1915819 100644 --- a/libs/@local/hashql/mir/src/reify/terminator.rs +++ b/libs/@local/hashql/mir/src/reify/terminator.rs @@ -159,8 +159,8 @@ impl<'mir, 'heap> Reifier<'_, 'mir, '_, '_, 'heap> { discriminant: test, targets: SwitchTargets::new_if( self.context.mir.heap, - Target::block(then_entry, self.context.mir.interner), - Target::block(else_entry, self.context.mir.interner), + Target::block(then_entry.into()), + Target::block(else_entry.into()), ), }), }, diff --git a/libs/@local/hashql/mir/src/reify/transform.rs b/libs/@local/hashql/mir/src/reify/transform.rs index 69cfc8c2670..d0d0f0695c2 100644 --- a/libs/@local/hashql/mir/src/reify/transform.rs +++ b/libs/@local/hashql/mir/src/reify/transform.rs @@ -69,10 +69,7 @@ impl<'mir, 'heap> Reifier<'_, 'mir, '_, '_, 'heap> { continue; }; - tuple_elements.push(Operand::Place(Place::local( - capture_local, - self.context.mir.interner, - ))); + tuple_elements.push(Operand::Place(Place::local(capture_local))); tuple_element_ty.push(self.local_decls[capture_local].r#type); } @@ -88,7 +85,7 @@ impl<'mir, 'heap> Reifier<'_, 'mir, '_, '_, 'heap> { block.push_statement(Statement { span: hir.span, kind: StatementKind::Assign(Assign { - lhs: Place::local(env_local, self.context.mir.interner), + lhs: Place::local(env_local), rhs: RValue::Aggregate(Aggregate { kind: AggregateKind::Tuple, operands: tuple_elements, @@ -126,7 +123,7 @@ impl<'mir, 'heap> Reifier<'_, 'mir, '_, '_, 'heap> { block.push_statement(Statement { span: binding.span, kind: StatementKind::Assign(Assign { - lhs: Place::local(local, self.context.mir.interner), + lhs: Place::local(local), rhs: rvalue, }), }); diff --git a/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/block_param_copy.snap b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/block_param_copy.snap new file mode 100644 index 00000000000..59e943e3318 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/block_param_copy.snap @@ -0,0 +1,59 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs +expression: value +--- +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Boolean + let %2: Integer + let %3: Boolean + + bb0(): { + %1 = 1 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + goto -> bb3(%0) + } + + bb3(%2): { + %3 = %2 == %2 + + return %3 + } +} + +================== Changed: Yes ================== + +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Boolean + let %2: Integer + let %3: Boolean + + bb0(): { + %1 = 1 + + switchInt(1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + goto -> bb3(%0) + } + + bb3(%2): { + %3 = %0 == %0 + + return %3 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/block_param_copy_disagreement.snap b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/block_param_copy_disagreement.snap new file mode 100644 index 00000000000..860c8f29c6d --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/block_param_copy_disagreement.snap @@ -0,0 +1,61 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs +expression: value +--- +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Boolean + let %3: Integer + let %4: Boolean + + bb0(): { + %2 = 1 + + switchInt(%2) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + goto -> bb3(%1) + } + + bb3(%3): { + %4 = %3 == %3 + + return %4 + } +} + +================== Changed: Yes ================== + +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Boolean + let %3: Integer + let %4: Boolean + + bb0(): { + %2 = 1 + + switchInt(1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + goto -> bb3(%1) + } + + bb3(%3): { + %4 = %3 == %3 + + return %4 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/copy_chain.snap b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/copy_chain.snap new file mode 100644 index 00000000000..d3c0977d4ea --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/copy_chain.snap @@ -0,0 +1,35 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs +expression: value +--- +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Integer + let %3: Boolean + + bb0(): { + %1 = %0 + %2 = %1 + %3 = %2 == %2 + + return %3 + } +} + +================== Changed: Yes ================== + +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Integer + let %3: Boolean + + bb0(): { + %1 = %0 + %2 = %0 + %3 = %0 == %0 + + return %3 + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/simple_copy.snap b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/simple_copy.snap new file mode 100644 index 00000000000..4ab860f1ba9 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/copy_propagation/simple_copy.snap @@ -0,0 +1,31 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs +expression: value +--- +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Boolean + + bb0(): { + %1 = %0 + %2 = %1 == %1 + + return %2 + } +} + +================== Changed: Yes ================== + +fn {closure@4294967040}() -> Boolean { + let %0: Integer + let %1: Integer + let %2: Boolean + + bb0(): { + %1 = %0 + %2 = %0 == %0 + + return %2 + } +}