diff --git a/.claude/skills/testing-hashql/references/mir-builder-guide.md b/.claude/skills/testing-hashql/references/mir-builder-guide.md index 3e15a4256ec..1967b7e4af7 100644 --- a/.claude/skills/testing-hashql/references/mir-builder-guide.md +++ b/.claude/skills/testing-hashql/references/mir-builder-guide.md @@ -52,6 +52,18 @@ body!(interner, env; @ / -> { }) ``` +**Important:** Only a single `decl` statement is supported. Declare all locals in one comma-separated list: + +```rust +// ✅ Correct - single decl with all locals +decl env: (), vertex: Entity, x: Int, y: Int, result: Bool; + +// ❌ Wrong - multiple decl statements will not compile +decl env: (), vertex: Entity; +decl x: Int, y: Int; +decl result: Bool; +``` + ### Header | Component | Description | Example | @@ -87,6 +99,7 @@ The `` can be a numeric literal (`0`, `1`, `42`) or a variable identifier (` | `(a: T1, b: T2)` | Struct types | `(a: Int, b: Bool)` | | `[List T]` | List type (intrinsic) | `[List Int]`, `[List (Int, Bool)]` | | `[fn(T1, T2) -> R]` | Closure types | `[fn(Int) -> Int]`, `[fn() -> Bool]` | +| `[Opaque path; T]` | Opaque type with symbol path | `[Opaque sym::path::Entity; ?]` | | `\|types\| types.custom()` | Custom type expression | `\|t\| t.null()` | ### Projections (Optional) @@ -97,12 +110,18 @@ Declare field projections after `decl` to access struct/tuple fields as places: @proj = .: , ...; ``` -Supports nested projections: +**Field access modes:** + +- Numeric index (e.g., `tup.0`) → `ProjectionKind::Field` +- Named field (e.g., `entity.metadata`) → `ProjectionKind::FieldByName` + +Each `@proj` declaration supports only ONE field after the base. 
For deeper paths, chain through intermediate declarations: ```rust let body = body!(interner, env; fn@0/0 -> Int { decl tup: ((Int, Int), Int), result: Int; - @proj inner = tup.0: (Int, Int), inner_1 = tup.0.1: Int; + // inner uses tup as base, inner_1 uses inner as base + @proj inner = tup.0: (Int, Int), inner_1 = inner.1: Int; bb0() { result = load inner_1; @@ -111,6 +130,22 @@ let body = body!(interner, env; fn@0/0 -> Int { }); ``` +Named field projections for opaque types: + +```rust +use hashql_core::symbol::sym; + +let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool { + decl env: (), vertex: [Opaque sym::path::Entity; ?]; + // Chain: vertex -> metadata -> archived + @proj metadata = vertex.metadata: ?, archived = metadata.archived: Bool; + + bb0() { + return archived; + } +}); +``` + ### Statements | Syntax | Description | MIR Equivalent | diff --git a/AGENTS.md b/AGENTS.md index 96e48e9363a..62096cc9c5a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -84,7 +84,7 @@ For Rust packages, you can add features as needed with `--all-features`, specifi CRITICAL: For the files referenced below, use your Read tool to load it on a need-to-know basis, ONLY when relevant to the SPECIFIC task at hand: -- @.config/agents/rules/*.md +- .config/agents/rules/*.md Instructions: diff --git a/Cargo.toml b/Cargo.toml index f9028196215..ed2056809ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,7 +93,6 @@ hashql-syntax-jexpr.path = "libs/@local/hashql/syntax-jexpr" type-system.path = "libs/@blockprotocol/type-system/rust" # External dependencies -allocator-api2 = { version = "0.2.8", default-features = false } annotate-snippets = { version = "0.12.8", default-features = false } ansi-to-html = { version = "0.2.2", default-features = false } anstream = { version = "0.6.21", default-features = false } diff --git a/libs/@local/hashql/compiletest/src/suite/mir_reify.rs b/libs/@local/hashql/compiletest/src/suite/mir_reify.rs index 1dc21a76ae7..c94207f534f 100644 --- 
a/libs/@local/hashql/compiletest/src/suite/mir_reify.rs +++ b/libs/@local/hashql/compiletest/src/suite/mir_reify.rs @@ -20,7 +20,7 @@ use hashql_mir::{ context::MirContext, def::{DefId, DefIdSlice, DefIdVec}, intern::Interner, - pretty::{D2Buffer, D2Format, TextFormat}, + pretty::{D2Buffer, D2Format, TextFormatOptions}, }; use super::{RunContext, Suite, SuiteDiagnostic, SuiteDirectives, common::process_status}; @@ -87,12 +87,15 @@ pub(crate) fn mir_format_text<'heap>( TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer, indent: 4, sources: bodies, types, - }; + annotations: (), + } + .build(); + text_format .format(bodies, &[root]) .expect("should be able to write bodies"); diff --git a/libs/@local/hashql/core/src/heap/mod.rs b/libs/@local/hashql/core/src/heap/mod.rs index f2f33329d28..23c30e7c6d1 100644 --- a/libs/@local/hashql/core/src/heap/mod.rs +++ b/libs/@local/hashql/core/src/heap/mod.rs @@ -256,7 +256,7 @@ impl Heap { strings.reserve(TABLES.iter().map(|table| table.len()).sum()); for &table in TABLES { - for &symbol in table { + for symbol in table { assert!(strings.insert(symbol.as_str())); } } diff --git a/libs/@local/hashql/core/src/id/bit_vec/mod.rs b/libs/@local/hashql/core/src/id/bit_vec/mod.rs index 4764b828ca3..1f73205001a 100644 --- a/libs/@local/hashql/core/src/id/bit_vec/mod.rs +++ b/libs/@local/hashql/core/src/id/bit_vec/mod.rs @@ -228,6 +228,15 @@ impl DenseBitSet { new_word != word } + #[inline] + pub fn set(&mut self, elem: T, value: bool) -> bool { + if value { + self.insert(elem) + } else { + self.remove(elem) + } + } + #[inline] pub fn insert_range(&mut self, elems: impl RangeBounds) { let Some((start, end)) = inclusive_start_end(elems, self.domain_size) else { diff --git a/libs/@local/hashql/core/src/module/std_lib/graph/types/knowledge/entity.rs b/libs/@local/hashql/core/src/module/std_lib/graph/types/knowledge/entity.rs index 
664e6965ab5..cf430046c71 100644 --- a/libs/@local/hashql/core/src/module/std_lib/graph/types/knowledge/entity.rs +++ b/libs/@local/hashql/core/src/module/std_lib/graph/types/knowledge/entity.rs @@ -4,7 +4,7 @@ use crate::{ StandardLibrary, std_lib::{self, ItemDef, ModuleDef, StandardLibraryModule, core::option::option}, }, - symbol::Symbol, + symbol::{Symbol, sym}, }; pub(in crate::module::std_lib) struct Entity { @@ -106,7 +106,7 @@ impl<'heap> StandardLibraryModule<'heap> for Entity { let entity_ty = lib.ty.generic( [t_arg], lib.ty.opaque( - "::graph::types::knowledge::entity::Entity", + sym::path::Entity, lib.ty.r#struct([ ("id", entity_record_id_ty), ("properties", t_param), diff --git a/libs/@local/hashql/core/src/symbol/sym.rs b/libs/@local/hashql/core/src/symbol/sym.rs index 226f2f1d179..e276aa5b355 100644 --- a/libs/@local/hashql/core/src/symbol/sym.rs +++ b/libs/@local/hashql/core/src/symbol/sym.rs @@ -61,67 +61,77 @@ macro_rules! symbols { } symbols![lexical; LEXICAL; - BaseUrl, - Boolean, - Dict, - E, - Err, - Integer, - Intersection, - List, - Never, - None, - Null, - Number, - Ok, - R, - Result, - Some, - String, - T, - U, - Union, - Unknown, - Url, access, add, and, + archived, + archived_by_id, + bar, + BaseUrl, bit_and, bit_not, bit_or, bit_shl, bit_shr, bit_xor, + Boolean, collect, + confidence, core, + created_at_decision_time, + created_at_transaction_time, + created_by_id, + decision_time, + Dict, div, draft_id, + E, + edition, + edition_id, + encodings, entity, entity_edition_id, entity_id, + entity_type_ids, entity_uuid, eq, + Err, filter, + foo, gt, gte, id, index, + inferred, input, input_exists: "$exists", + Integer, + Intersection, kernel, + left_entity_confidence, left_entity_id, + left_entity_provenance, link_data, + List, lt, lte, math, + metadata, mul, ne, + Never, + None, not, + Null, null, + Number, + Ok, option, or, pow, properties, + provenance, + provided, r#as: "as", r#as_force: "as!", r#else: "else", @@ -136,11 +146,27 @@ 
symbols![lexical; LEXICAL; r#true: "true", r#type: "type", r#use: "use", + R, + record_id, + Result, + right_entity_confidence, right_entity_id, + right_entity_provenance, + Some, special_form, + String, sub, + T, + temporal_versioning, then: "then", thunk: "thunk", + transaction_time, + U, + Union, + Unknown, + unknown, + Url, + vectors, web_id, ]; @@ -202,6 +228,7 @@ symbols![path; PATHS; graph_head_entities: "::graph::head::entities", graph_body_filter: "::graph::body::filter", graph_tail_collect: "::graph::tail::collect", + Entity: "::graph::types::knowledge::entity::Entity" ]; pub(crate) const TABLES: &[&[&Symbol<'static>]] = &[LEXICAL, DIGITS, SYMBOLS, PATHS, INTERNAL]; diff --git a/libs/@local/hashql/core/src/type/inference/visit.rs b/libs/@local/hashql/core/src/type/inference/visit.rs index f1f9e0e0db5..0f43e7d2887 100644 --- a/libs/@local/hashql/core/src/type/inference/visit.rs +++ b/libs/@local/hashql/core/src/type/inference/visit.rs @@ -66,61 +66,67 @@ impl<'env, 'heap> VariableDependencyCollector<'env, 'heap> { impl<'heap> Visitor<'heap> for VariableDependencyCollector<'_, 'heap> { type Filter = VariableVisitorFilter; + type Result = Result<(), !>; fn env(&self) -> &Environment<'heap> { self.env } - fn visit_type(&mut self, r#type: Type<'heap>) { + fn visit_type(&mut self, r#type: Type<'heap>) -> Self::Result { if self.recursion.enter(r#type, r#type).is_break() { // recursive type definition - return; + return Ok(()); } let previous = self.current_span; self.current_span = r#type.span; - visit::walk_type(self, r#type); + Ok(()) = visit::walk_type(self, r#type); self.current_span = previous; self.recursion.exit(r#type, r#type); + Ok(()) } - fn visit_generic_argument(&mut self, argument: GenericArgument<'heap>) { + fn visit_generic_argument(&mut self, argument: GenericArgument<'heap>) -> Self::Result { // We only depend on the introduced variable, but **not** the constraint itself, therefore // we don't walk the argument. 
self.variables.push(Variable { span: self.current_span, kind: VariableKind::Generic(argument.id), }); + Ok(()) } - fn visit_generic_substitution(&mut self, substitution: GenericSubstitution) { + fn visit_generic_substitution(&mut self, substitution: GenericSubstitution) -> Self::Result { // We only depend on the introduced variable, but **not** the constraint itself, therefore // we don't walk the substitution. self.variables.push(Variable { span: self.current_span, kind: VariableKind::Generic(substitution.argument), }); + Ok(()) } - fn visit_param(&mut self, param: Type<'heap, Param>) { - visit::walk_param(self, param); + fn visit_param(&mut self, param: Type<'heap, Param>) -> Self::Result { + Ok(()) = visit::walk_param(self, param); self.variables.push(Variable { span: param.span, kind: VariableKind::Generic(param.kind.argument), }); + Ok(()) } - fn visit_infer(&mut self, infer: Type<'heap, Infer>) { - visit::walk_infer(self, infer); + fn visit_infer(&mut self, infer: Type<'heap, Infer>) -> Self::Result { + Ok(()) = visit::walk_infer(self, infer); self.variables.push(Variable { span: infer.span, kind: VariableKind::Hole(infer.kind.hole), }); + Ok(()) } } @@ -149,42 +155,46 @@ impl<'env, 'heap> VariableCollector<'env, 'heap> { impl<'heap> Visitor<'heap> for VariableCollector<'_, 'heap> { type Filter = filter::Deep; + type Result = Result<(), !>; fn env(&self) -> &Environment<'heap> { self.env } - fn visit_type(&mut self, r#type: Type<'heap>) { + fn visit_type(&mut self, r#type: Type<'heap>) -> Self::Result { if self.recursion.enter(r#type, r#type).is_break() { // recursive type definition - return; + return Ok(()); } let previous = self.current_span; self.current_span = r#type.span; - visit::walk_type(self, r#type); + Ok(()) = visit::walk_type(self, r#type); self.current_span = previous; self.recursion.exit(r#type, r#type); + Ok(()) } - fn visit_param(&mut self, param: Type<'heap, Param>) { - visit::walk_param(self, param); + fn visit_param(&mut self, param: 
Type<'heap, Param>) -> Self::Result { + Ok(()) = visit::walk_param(self, param); self.variables.push(Variable { span: param.span, kind: VariableKind::Generic(param.kind.argument), }); + Ok(()) } - fn visit_infer(&mut self, infer: Type<'heap, Infer>) { - visit::walk_infer(self, infer); + fn visit_infer(&mut self, infer: Type<'heap, Infer>) -> Self::Result { + Ok(()) = visit::walk_infer(self, infer); self.variables.push(Variable { span: infer.span, kind: VariableKind::Hole(infer.kind.hole), }); + Ok(()) } } diff --git a/libs/@local/hashql/core/src/type/visit.rs b/libs/@local/hashql/core/src/type/visit.rs index d75b8b1ecf5..bf4ad690a2c 100644 --- a/libs/@local/hashql/core/src/type/visit.rs +++ b/libs/@local/hashql/core/src/type/visit.rs @@ -1,3 +1,5 @@ +use core::ops::Try; + use self::filter::{Deep, Filter as _}; use super::{ Type, TypeId, @@ -9,6 +11,7 @@ use super::{ intrinsic::{DictType, ListType}, r#struct::{StructField, StructFields}, }, + recursion::RecursionBoundary, }; pub mod filter { @@ -104,6 +107,12 @@ pub mod filter { } } +macro_rules! Ok { + () => { + Try::from_output(()) + }; +} + /// A visitor for traversing and analyzing the type system. /// /// To implement a custom type visitor, create a type that implements this trait @@ -132,141 +141,155 @@ pub mod filter { /// recursion. 
pub trait Visitor<'heap> { type Filter: filter::Filter = Deep; + type Result: Try; fn env(&self) -> &Environment<'heap>; - fn visit_generic_arguments(&mut self, arguments: GenericArguments<'heap>) { - walk_generic_arguments(self, arguments); + fn visit_generic_arguments(&mut self, arguments: GenericArguments<'heap>) -> Self::Result { + walk_generic_arguments(self, arguments) } - fn visit_generic_argument(&mut self, argument: GenericArgument<'heap>) { - walk_generic_argument(self, argument); + fn visit_generic_argument(&mut self, argument: GenericArgument<'heap>) -> Self::Result { + walk_generic_argument(self, argument) } - fn visit_generic_substitutions(&mut self, substitutions: GenericSubstitutions<'heap>) { - walk_generic_substitutions(self, substitutions); + fn visit_generic_substitutions( + &mut self, + substitutions: GenericSubstitutions<'heap>, + ) -> Self::Result { + walk_generic_substitutions(self, substitutions) } - fn visit_generic_substitution(&mut self, substitution: GenericSubstitution) { - walk_generic_substitution(self, substitution); + fn visit_generic_substitution(&mut self, substitution: GenericSubstitution) -> Self::Result { + walk_generic_substitution(self, substitution) } - fn visit_id(&mut self, id: TypeId) { - walk_id(self, id); + fn visit_id(&mut self, id: TypeId) -> Self::Result { + walk_id(self, id) } - fn visit_type(&mut self, r#type: Type<'heap>) { - walk_type(self, r#type); + fn visit_type(&mut self, r#type: Type<'heap>) -> Self::Result { + walk_type(self, r#type) } - fn visit_opaque(&mut self, opaque: Type<'heap, OpaqueType>) { - walk_opaque(self, opaque); + fn visit_opaque(&mut self, opaque: Type<'heap, OpaqueType>) -> Self::Result { + walk_opaque(self, opaque) } #[expect(unused_variables, reason = "trait definition")] - fn visit_primitive(&mut self, primitive: Type<'heap, PrimitiveType>) { + fn visit_primitive(&mut self, primitive: Type<'heap, PrimitiveType>) -> Self::Result { // Do nothing, there's nothing to walk + Ok!() } - fn 
visit_intrinsic_list(&mut self, list: Type<'heap, ListType>) { - walk_intrinsic_list(self, list); + fn visit_intrinsic_list(&mut self, list: Type<'heap, ListType>) -> Self::Result { + walk_intrinsic_list(self, list) } - fn visit_intrinsic_dict(&mut self, dict: Type<'heap, DictType>) { - walk_intrinsic_dict(self, dict); + fn visit_intrinsic_dict(&mut self, dict: Type<'heap, DictType>) -> Self::Result { + walk_intrinsic_dict(self, dict) } - fn visit_intrinsic(&mut self, intrinsic: Type<'heap, IntrinsicType>) { - walk_intrinsic(self, intrinsic); + fn visit_intrinsic(&mut self, intrinsic: Type<'heap, IntrinsicType>) -> Self::Result { + walk_intrinsic(self, intrinsic) } - fn visit_struct(&mut self, r#struct: Type<'heap, StructType>) { - walk_struct(self, r#struct); + fn visit_struct(&mut self, r#struct: Type<'heap, StructType>) -> Self::Result { + walk_struct(self, r#struct) } - fn visit_struct_fields(&mut self, fields: StructFields<'heap>) { - walk_struct_fields(self, fields); + fn visit_struct_fields(&mut self, fields: StructFields<'heap>) -> Self::Result { + walk_struct_fields(self, fields) } - fn visit_struct_field(&mut self, field: StructField<'heap>) { - walk_struct_field(self, field); + fn visit_struct_field(&mut self, field: StructField<'heap>) -> Self::Result { + walk_struct_field(self, field) } - fn visit_tuple(&mut self, tuple: Type<'heap, TupleType>) { - walk_tuple(self, tuple); + fn visit_tuple(&mut self, tuple: Type<'heap, TupleType>) -> Self::Result { + walk_tuple(self, tuple) } - fn visit_union(&mut self, union: Type<'heap, UnionType>) { - walk_union(self, union); + fn visit_union(&mut self, union: Type<'heap, UnionType>) -> Self::Result { + walk_union(self, union) } - fn visit_intersection(&mut self, intersection: Type<'heap, IntersectionType>) { - walk_intersection(self, intersection); + fn visit_intersection(&mut self, intersection: Type<'heap, IntersectionType>) -> Self::Result { + walk_intersection(self, intersection) } - fn visit_closure(&mut self, 
closure: Type<'heap, ClosureType>) { - walk_closure(self, closure); + fn visit_closure(&mut self, closure: Type<'heap, ClosureType>) -> Self::Result { + walk_closure(self, closure) } - fn visit_apply(&mut self, apply: Type<'heap, Apply>) { - walk_apply(self, apply); + fn visit_apply(&mut self, apply: Type<'heap, Apply>) -> Self::Result { + walk_apply(self, apply) } - fn visit_generic(&mut self, generic: Type<'heap, Generic<'heap>>) { - walk_generic(self, generic); + fn visit_generic(&mut self, generic: Type<'heap, Generic<'heap>>) -> Self::Result { + walk_generic(self, generic) } - fn visit_param(&mut self, param: Type<'heap, Param>) { - walk_param(self, param); + fn visit_param(&mut self, param: Type<'heap, Param>) -> Self::Result { + walk_param(self, param) } - fn visit_infer(&mut self, infer: Type<'heap, Infer>) { - walk_infer(self, infer); + fn visit_infer(&mut self, infer: Type<'heap, Infer>) -> Self::Result { + walk_infer(self, infer) } } pub fn walk_generic_arguments<'heap, V: Visitor<'heap> + ?Sized>( visitor: &mut V, generic_arguments: GenericArguments<'heap>, -) { +) -> V::Result { for &generic_argument in generic_arguments.iter() { - visitor.visit_generic_argument(generic_argument); + visitor.visit_generic_argument(generic_argument)?; } + + Ok!() } pub fn walk_generic_argument<'heap, V: Visitor<'heap> + ?Sized>( visitor: &mut V, generic_argument: GenericArgument<'heap>, -) { +) -> V::Result { if let Some(constraint) = generic_argument.constraint { - visitor.visit_id(constraint); + visitor.visit_id(constraint)?; } + + Ok!() } pub fn walk_generic_substitutions<'heap, V: Visitor<'heap> + ?Sized>( visitor: &mut V, substitutions: GenericSubstitutions<'heap>, -) { +) -> V::Result { for &substitution in substitutions.iter() { - visitor.visit_generic_substitution(substitution); + visitor.visit_generic_substitution(substitution)?; } + + Ok!() } pub fn walk_generic_substitution<'heap, V: Visitor<'heap> + ?Sized>( visitor: &mut V, GenericSubstitution { argument: _, 
value }: GenericSubstitution, -) { - visitor.visit_id(value); +) -> V::Result { + visitor.visit_id(value)?; + + Ok!() } -pub fn walk_id<'heap, V: Visitor<'heap> + ?Sized>(visitor: &mut V, id: TypeId) { +pub fn walk_id<'heap, V: Visitor<'heap> + ?Sized>(visitor: &mut V, id: TypeId) -> V::Result { if !V::Filter::DEEP { - return; + return Ok!(); } let r#type = visitor.env().r#type(id); - visitor.visit_type(r#type); + visitor.visit_type(r#type)?; + Ok!() } pub fn walk_type<'heap, V: Visitor<'heap> + ?Sized>( @@ -276,7 +299,7 @@ pub fn walk_type<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind, }: Type<'heap>, -) { +) -> V::Result { match kind { TypeKind::Opaque(opaque) => visitor.visit_opaque(r#type.with(opaque)), TypeKind::Primitive(primitive) => visitor.visit_primitive(r#type.with(primitive)), @@ -285,14 +308,14 @@ pub fn walk_type<'heap, V: Visitor<'heap> + ?Sized>( TypeKind::Tuple(tuple) => visitor.visit_tuple(r#type.with(tuple)), TypeKind::Union(union) => visitor.visit_union(r#type.with(union)), TypeKind::Intersection(intersection) => { - visitor.visit_intersection(r#type.with(intersection)); + visitor.visit_intersection(r#type.with(intersection)) } TypeKind::Closure(closure) => visitor.visit_closure(r#type.with(closure)), TypeKind::Apply(apply) => visitor.visit_apply(r#type.with(apply)), TypeKind::Generic(generic) => visitor.visit_generic(r#type.with(generic)), TypeKind::Param(param) => visitor.visit_param(r#type.with(param)), TypeKind::Infer(infer) => visitor.visit_infer(r#type.with(infer)), - TypeKind::Never | TypeKind::Unknown => {} + TypeKind::Never | TypeKind::Unknown => Ok!(), } } @@ -303,10 +326,12 @@ pub fn walk_opaque<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &OpaqueType { name: _, repr }, }: Type<'heap, OpaqueType>, -) { +) -> V::Result { if V::Filter::GENERIC_PARAMETERS { - visitor.visit_id(repr); + visitor.visit_id(repr)?; } + + Ok!() } pub fn walk_intrinsic_list<'heap, V: Visitor<'heap> + ?Sized>( @@ -316,10 +341,12 @@ pub fn 
walk_intrinsic_list<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &ListType { element }, }: Type<'heap, ListType>, -) { +) -> V::Result { if V::Filter::GENERIC_PARAMETERS { - visitor.visit_id(element); + visitor.visit_id(element)?; } + + Ok!() } pub fn walk_intrinsic_dict<'heap, V: Visitor<'heap> + ?Sized>( @@ -329,11 +356,13 @@ pub fn walk_intrinsic_dict<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &DictType { key, value }, }: Type<'heap, DictType>, -) { +) -> V::Result { if V::Filter::GENERIC_PARAMETERS { - visitor.visit_id(key); - visitor.visit_id(value); + visitor.visit_id(key)?; + visitor.visit_id(value)?; } + + Ok!() } pub fn walk_intrinsic<'heap, V: Visitor<'heap> + ?Sized>( @@ -343,7 +372,7 @@ pub fn walk_intrinsic<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind, }: Type<'heap, IntrinsicType>, -) { +) -> V::Result { match kind { IntrinsicType::List(list) => visitor.visit_intrinsic_list(intrinsic.with(list)), IntrinsicType::Dict(dict) => visitor.visit_intrinsic_dict(intrinsic.with(dict)), @@ -357,28 +386,33 @@ pub fn walk_struct<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &StructType { fields }, }: Type<'heap, StructType>, -) { - visitor.visit_struct_fields(fields); +) -> V::Result { + visitor.visit_struct_fields(fields)?; + Ok!() } pub fn walk_struct_fields<'heap, V: Visitor<'heap> + ?Sized>( visitor: &mut V, fields: StructFields<'heap>, -) { +) -> V::Result { if !V::Filter::MEMBERS { - return; + return Ok!(); } for &field in fields.iter() { - visitor.visit_struct_field(field); + visitor.visit_struct_field(field)?; } + + Ok!() } pub fn walk_struct_field<'heap, V: Visitor<'heap> + ?Sized>( visitor: &mut V, StructField { name: _, value }: StructField<'heap>, -) { - visitor.visit_id(value); +) -> V::Result { + visitor.visit_id(value)?; + + Ok!() } pub fn walk_tuple<'heap, V: Visitor<'heap> + ?Sized>( @@ -388,14 +422,16 @@ pub fn walk_tuple<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &TupleType { fields }, }: Type<'heap, TupleType>, 
-) { +) -> V::Result { if !V::Filter::MEMBERS { - return; + return Ok!(); } for &field in fields { - visitor.visit_id(field); + visitor.visit_id(field)?; } + + Ok!() } pub fn walk_union<'heap, V: Visitor<'heap> + ?Sized>( @@ -405,10 +441,12 @@ pub fn walk_union<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &UnionType { variants }, }: Type<'heap, UnionType>, -) { +) -> V::Result { for &variant in variants { - visitor.visit_id(variant); + visitor.visit_id(variant)?; } + + Ok!() } pub fn walk_intersection<'heap, V: Visitor<'heap> + ?Sized>( @@ -418,10 +456,12 @@ pub fn walk_intersection<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &IntersectionType { variants }, }: Type<'heap, IntersectionType>, -) { +) -> V::Result { for &variant in variants { - visitor.visit_id(variant); + visitor.visit_id(variant)?; } + + Ok!() } pub fn walk_closure<'heap, V: Visitor<'heap> + ?Sized>( @@ -431,16 +471,17 @@ pub fn walk_closure<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &ClosureType { params, returns }, }: Type<'heap, ClosureType>, -) { +) -> V::Result { if !V::Filter::MEMBERS { - return; + return Ok!(); } for ¶m in params { - visitor.visit_id(param); + visitor.visit_id(param)?; } - visitor.visit_id(returns); + visitor.visit_id(returns)?; + Ok!() } pub fn walk_apply<'heap, V: Visitor<'heap> + ?Sized>( @@ -453,9 +494,11 @@ pub fn walk_apply<'heap, V: Visitor<'heap> + ?Sized>( substitutions, }, }: Type<'heap, Apply>, -) { - visitor.visit_generic_substitutions(substitutions); - visitor.visit_id(base); +) -> V::Result { + visitor.visit_generic_substitutions(substitutions)?; + visitor.visit_id(base)?; + + Ok!() } pub fn walk_generic<'heap, V: Visitor<'heap> + ?Sized>( @@ -465,9 +508,11 @@ pub fn walk_generic<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &Generic { base, arguments }, }: Type<'heap, Generic>, -) { - visitor.visit_generic_arguments(arguments); - visitor.visit_id(base); +) -> V::Result { + visitor.visit_generic_arguments(arguments)?; + 
visitor.visit_id(base)?; + + Ok!() } pub fn walk_param<'heap, V: Visitor<'heap> + ?Sized>( @@ -477,16 +522,17 @@ pub fn walk_param<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &Param { argument }, }: Type<'heap, Param>, -) { +) -> V::Result { if !V::Filter::SUBSTITUTIONS { - return; + return Ok!(); } let Some(substitution) = visitor.env().substitution.argument(argument) else { - return; + return Ok!(); }; - visitor.visit_id(substitution); + visitor.visit_id(substitution)?; + Ok!() } pub fn walk_infer<'heap, V: Visitor<'heap> + ?Sized>( @@ -496,14 +542,158 @@ pub fn walk_infer<'heap, V: Visitor<'heap> + ?Sized>( span: _, kind: &Infer { hole }, }: Type<'heap, Infer>, -) { +) -> V::Result { if !V::Filter::SUBSTITUTIONS { - return; + return Ok!(); } let Some(substitution) = visitor.env().substitution.infer(hole) else { - return; + return Ok!(); }; - visitor.visit_id(substitution); + visitor.visit_id(substitution)?; + Ok!() +} + +impl<'heap, V> Visitor<'heap> for &mut V +where + V: Visitor<'heap>, +{ + type Filter = V::Filter; + type Result = V::Result; + + fn env(&self) -> &Environment<'heap> { + V::env(self) + } + + fn visit_generic_arguments(&mut self, arguments: GenericArguments<'heap>) -> Self::Result { + V::visit_generic_arguments(self, arguments) + } + + fn visit_generic_argument(&mut self, argument: GenericArgument<'heap>) -> Self::Result { + V::visit_generic_argument(self, argument) + } + + fn visit_generic_substitutions( + &mut self, + substitutions: GenericSubstitutions<'heap>, + ) -> Self::Result { + V::visit_generic_substitutions(self, substitutions) + } + + fn visit_generic_substitution(&mut self, substitution: GenericSubstitution) -> Self::Result { + V::visit_generic_substitution(self, substitution) + } + + fn visit_id(&mut self, id: TypeId) -> Self::Result { + V::visit_id(self, id) + } + + fn visit_type(&mut self, r#type: Type<'heap>) -> Self::Result { + V::visit_type(self, r#type) + } + + fn visit_opaque(&mut self, opaque: Type<'heap, OpaqueType>) 
-> Self::Result { + V::visit_opaque(self, opaque) + } + + fn visit_primitive(&mut self, primitive: Type<'heap, PrimitiveType>) -> Self::Result { + V::visit_primitive(self, primitive) + } + + fn visit_intrinsic_list(&mut self, list: Type<'heap, ListType>) -> Self::Result { + V::visit_intrinsic_list(self, list) + } + + fn visit_intrinsic_dict(&mut self, dict: Type<'heap, DictType>) -> Self::Result { + V::visit_intrinsic_dict(self, dict) + } + + fn visit_intrinsic(&mut self, intrinsic: Type<'heap, IntrinsicType>) -> Self::Result { + V::visit_intrinsic(self, intrinsic) + } + + fn visit_struct(&mut self, r#struct: Type<'heap, StructType>) -> Self::Result { + V::visit_struct(self, r#struct) + } + + fn visit_struct_fields(&mut self, fields: StructFields<'heap>) -> Self::Result { + V::visit_struct_fields(self, fields) + } + + fn visit_struct_field(&mut self, field: StructField<'heap>) -> Self::Result { + V::visit_struct_field(self, field) + } + + fn visit_tuple(&mut self, tuple: Type<'heap, TupleType>) -> Self::Result { + V::visit_tuple(self, tuple) + } + + fn visit_union(&mut self, union: Type<'heap, UnionType>) -> Self::Result { + V::visit_union(self, union) + } + + fn visit_intersection(&mut self, intersection: Type<'heap, IntersectionType>) -> Self::Result { + V::visit_intersection(self, intersection) + } + + fn visit_closure(&mut self, closure: Type<'heap, ClosureType>) -> Self::Result { + V::visit_closure(self, closure) + } + + fn visit_apply(&mut self, apply: Type<'heap, Apply>) -> Self::Result { + V::visit_apply(self, apply) + } + + fn visit_generic(&mut self, generic: Type<'heap, Generic<'heap>>) -> Self::Result { + V::visit_generic(self, generic) + } + + fn visit_param(&mut self, param: Type<'heap, Param>) -> Self::Result { + V::visit_param(self, param) + } + + fn visit_infer(&mut self, infer: Type<'heap, Infer>) -> Self::Result { + V::visit_infer(self, infer) + } +} + +pub struct RecursiveVisitorGuard<'heap> { + boundary: RecursionBoundary<'heap>, +} + 
+impl<'heap> RecursiveVisitorGuard<'heap> { + #[must_use] + pub fn new() -> Self { + Self { + boundary: RecursionBoundary::new(), + } + } + + pub fn with>( + &mut self, + visit: impl FnOnce(&mut Self, Type<'heap>) -> T, + r#type: Type<'heap>, + ) -> T { + if self.boundary.enter(r#type, r#type).is_break() { + return Ok!(); + } + + let result = visit(self, r#type); + + self.boundary.exit(r#type, r#type); + result + } +} + +impl Default for RecursiveVisitorGuard<'_> { + fn default() -> Self { + Self::new() + } +} + +impl AsMut for RecursiveVisitorGuard<'_> { + fn as_mut(&mut self) -> &mut Self { + self + } } diff --git a/libs/@local/hashql/mir/package.json b/libs/@local/hashql/mir/package.json index ea18c59cd7e..0afdbfe530c 100644 --- a/libs/@local/hashql/mir/package.json +++ b/libs/@local/hashql/mir/package.json @@ -9,7 +9,7 @@ "fix:clippy": "just clippy --fix", "lint:clippy": "just clippy", "test:codspeed": "cargo codspeed run -p hashql-mir", - "test:miri": "cargo miri nextest run -- changed_bitor interpret::locals::tests", + "test:miri": "cargo miri nextest run -- changed_bitor interpret::locals::tests pass::analysis::execution::cost", "test:unit": "mise run test:unit @rust/hashql-mir" }, "dependencies": { diff --git a/libs/@local/hashql/mir/src/builder/base.rs b/libs/@local/hashql/mir/src/builder/base.rs index 1f3ef064de3..65c8d6c7deb 100644 --- a/libs/@local/hashql/mir/src/builder/base.rs +++ b/libs/@local/hashql/mir/src/builder/base.rs @@ -38,7 +38,7 @@ impl<'env, 'heap> BaseBuilder<'env, 'heap> { /// Creates an integer constant operand. #[must_use] - pub fn const_int(self, value: i128) -> Operand<'heap> { + pub const fn const_int(self, value: i128) -> Operand<'heap> { Operand::Constant(Constant::Int(value.into())) } @@ -52,7 +52,7 @@ impl<'env, 'heap> BaseBuilder<'env, 'heap> { /// Creates a boolean constant operand. 
#[must_use] - pub fn const_bool(self, value: bool) -> Operand<'heap> { + pub const fn const_bool(self, value: bool) -> Operand<'heap> { Operand::Constant(Constant::Int(value.into())) } diff --git a/libs/@local/hashql/mir/src/builder/body.rs b/libs/@local/hashql/mir/src/builder/body.rs index 7544f792ae3..c5b1051fe36 100644 --- a/libs/@local/hashql/mir/src/builder/body.rs +++ b/libs/@local/hashql/mir/src/builder/body.rs @@ -262,7 +262,7 @@ macro_rules! body { $interner:ident, $env:ident; $type:tt @ $id:tt / $arity:literal -> $body_type:tt { decl $($param:ident: $param_type:tt),*; - $(@proj $($proj:ident = $proj_base:ident.$field:literal: $proj_type:tt),*;)? + $(@proj $($proj:ident = $proj_base:ident.$field:tt: $proj_type:tt),*;)? $($block:ident($($block_param:ident),*) $block_body:tt),+ } @@ -278,7 +278,7 @@ macro_rules! body { $( $( - let $proj = builder.place(|p| p.from($proj_base).field($field, $crate::builder::body!(@type types; $proj_type))); + let $proj = builder.place($crate::builder::body!(@proj types; $proj_base; $field; $proj_type)); )* )? @@ -308,6 +308,13 @@ macro_rules! body { $id }; + (@proj $types:ident; $proj_base:ident; $proj:literal; $proj_type:tt) => { + |p| p.from($proj_base).field($proj, $crate::builder::body!(@type $types; $proj_type)) + }; + (@proj $types:ident; $proj_base:ident; $proj:ident; $proj_type:tt) => { + |p| p.from($proj_base).field_by_name(stringify!($proj), $crate::builder::body!(@type $types; $proj_type)) + }; + (@type $types:ident; Int) => { $types.integer() }; @@ -326,6 +333,9 @@ macro_rules! 
body { (@type $types:ident; [List $sub:tt]) => { $types.list($crate::builder::body!(@type $types; $sub)) }; + (@type $types:ident; [Opaque $sym:path; $value:tt]) => { + $types.opaque($sym, $crate::builder::body!(@type $types; $value)) + }; (@type $types:ident; [fn($($args:tt),+) -> $ret:tt]) => { $types.closure([$($crate::builder::body!(@type $types; $args)),*], $crate::builder::body!(@type $types; $ret)) }; diff --git a/libs/@local/hashql/mir/src/interpret/value/int.rs b/libs/@local/hashql/mir/src/interpret/value/int.rs index 2b6cd769cf5..0e04839a950 100644 --- a/libs/@local/hashql/mir/src/interpret/value/int.rs +++ b/libs/@local/hashql/mir/src/interpret/value/int.rs @@ -485,7 +485,7 @@ macro_rules! impl_from { }; (@impl $ty:ty) => { - impl From<$ty> for Int { + impl const From<$ty> for Int { #[inline] fn from(value: $ty) -> Self { Self::from_value_unchecked(i128::from(value)) @@ -498,26 +498,29 @@ impl_from!(bool, u8, u16, u32, u64, i8, i16, i32, i64, i128); // `usize` and `isize` cannot use the macro because `i128::from()` doesn't accept // platform-dependent types. 
-impl From for Int { +impl const From for Int { #[inline] fn from(value: usize) -> Self { Self::from_value_unchecked(value as i128) } } -impl From for Int { +impl const From for Int { #[inline] fn from(value: isize) -> Self { Self::from_value_unchecked(value as i128) } } -impl TryFrom for Int { +impl const TryFrom for Int { type Error = TryFromIntError; #[inline] fn try_from(value: u128) -> Result { - Ok(Self::from_value_unchecked(i128::try_from(value)?)) + match i128::try_from(value) { + Ok(value) => Ok(Self::from_value_unchecked(value)), + Err(error) => Err(error), + } } } diff --git a/libs/@local/hashql/mir/src/lib.rs b/libs/@local/hashql/mir/src/lib.rs index 0492e4ea09b..2f2fac13ead 100644 --- a/libs/@local/hashql/mir/src/lib.rs +++ b/libs/@local/hashql/mir/src/lib.rs @@ -11,6 +11,7 @@ impl_trait_in_assoc_type, macro_metavar_expr_concat, never_type, + const_trait_impl, // Library Features allocator_api, @@ -30,6 +31,8 @@ step_trait, string_from_utf8_lossy_owned, try_trait_v2, + temporary_niche_types, + const_convert )] #![expect(clippy::indexing_slicing)] extern crate alloc; diff --git a/libs/@local/hashql/mir/src/pass/analysis/dataflow/lattice/impls.rs b/libs/@local/hashql/mir/src/pass/analysis/dataflow/lattice/impls.rs index 986d5b8cb26..df40960b7d3 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/dataflow/lattice/impls.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/dataflow/lattice/impls.rs @@ -1,5 +1,7 @@ //! Built-in semiring implementations for common numeric types. +use core::cmp::Reverse; + use hashql_core::id::{ Id, bit_vec::{BitRelations as _, ChunkedBitSet, DenseBitSet, MixedBitSet}, @@ -193,6 +195,54 @@ macro_rules! 
impl_bitset { impl_bitset!(DenseBitSet, ChunkedBitSet, MixedBitSet); +impl MeetSemiLattice for Reverse +where + U: JoinSemiLattice, +{ + #[inline] + fn meet(&self, lhs: &mut T, rhs: &T) -> bool { + self.0.join(lhs, rhs) + } +} + +impl HasBottom for Reverse +where + U: HasTop, +{ + #[inline] + fn bottom(&self) -> T { + self.0.top() + } + + fn is_bottom(&self, value: &T) -> bool { + self.0.is_top(value) + } +} + +impl JoinSemiLattice for Reverse +where + U: MeetSemiLattice, +{ + #[inline] + fn join(&self, lhs: &mut T, rhs: &T) -> bool { + self.0.meet(lhs, rhs) + } +} + +impl HasTop for Reverse +where + U: HasBottom, +{ + #[inline] + fn top(&self) -> T { + self.0.bottom() + } + + fn is_top(&self, value: &T) -> bool { + self.0.is_bottom(value) + } +} + #[cfg(test)] mod tests { #![expect(clippy::min_ident_chars)] diff --git a/libs/@local/hashql/mir/src/pass/analysis/dataflow/liveness/tests.rs b/libs/@local/hashql/mir/src/pass/analysis/dataflow/liveness/tests.rs index 3a6870ba37d..cb9cd7c0bf0 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/dataflow/liveness/tests.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/dataflow/liveness/tests.rs @@ -17,7 +17,7 @@ use crate::{ builder::body, intern::Interner, pass::analysis::dataflow::framework::{DataflowAnalysis as _, DataflowResults, Direction}, - pretty::TextFormat, + pretty::TextFormatOptions, }; fn format_liveness_state(mut write: impl fmt::Write, state: &DenseBitSet) -> fmt::Result { @@ -63,12 +63,14 @@ fn format_liveness( fn format_body<'heap>(env: &Environment<'heap>, body: &Body<'heap>) -> impl Display { let formatter = Formatter::new(env.heap); - let mut text_formatter = TextFormat { + let mut text_formatter = TextFormatOptions { writer: Vec::::new(), indent: 4, sources: (), types: TypeFormatter::new(&formatter, env, TypeFormatterOptions::terse()), - }; + annotations: (), + } + .build(); text_formatter.format_body(body).expect("infallible"); diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/cost.rs 
b/libs/@local/hashql/mir/src/pass/analysis/execution/cost.rs new file mode 100644 index 00000000000..59328e7820f --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/cost.rs @@ -0,0 +1,385 @@ +//! Cost tracking for execution planning. +//! +//! Provides data structures for recording the execution cost of statements on different targets. +//! The execution planner uses these costs to select optimal targets for each statement. + +use alloc::alloc::Global; +use core::{ + alloc::Allocator, + fmt, iter, + ops::{Index, IndexMut}, +}; + +use hashql_core::id::{Id as _, bit_vec::DenseBitSet}; + +use crate::{ + body::{ + Body, + basic_block::BasicBlockSlice, + basic_blocks::BasicBlocks, + local::{Local, LocalVec}, + location::Location, + }, + pass::transform::Traversals, +}; + +/// Execution cost for a statement on a particular target. +/// +/// Lower values indicate cheaper execution. When multiple targets can execute a statement, the +/// execution planner selects the target with the lowest cost. A statement with no assigned cost +/// (`None`) indicates the target cannot execute that statement. +/// +/// Typical cost values: +/// - `0`: Zero-cost operations (storage markers, nops) +/// - `4`: Standard Postgres/Embedding operations +/// - `8`: Interpreter operations (higher due to runtime overhead) +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Cost(core::num::niche_types::U32NotAllOnes); + +impl Cost { + /// Creates a cost from a `u32` value, returning `None` if the value is `u32::MAX`. + /// + /// The `u32::MAX` value is reserved as a niche for `Option` optimization. 
+ #[must_use] + pub const fn new(value: u32) -> Option { + match core::num::niche_types::U32NotAllOnes::new(value) { + Some(cost) => Some(Self(cost)), + None => None, + } + } + + #[must_use] + #[doc(hidden)] + #[track_caller] + pub const fn new_panic(value: u32) -> Self { + match core::num::niche_types::U32NotAllOnes::new(value) { + Some(cost) => Self(cost), + None => panic!("invalid cost value"), + } + } + + /// Creates a cost without checking whether the value is valid. + /// + /// # Safety + /// + /// The caller must ensure `value` is not `u32::MAX`. + #[must_use] + #[expect(unsafe_code)] + pub const unsafe fn new_unchecked(value: u32) -> Self { + // SAFETY: The caller must ensure `value` is not `u32::MAX`. + Self(unsafe { core::num::niche_types::U32NotAllOnes::new_unchecked(value) }) + } +} + +impl fmt::Display for Cost { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0.as_inner(), fmt) + } +} + +/// Sparse cost map for traversal locals. +/// +/// Traversals are locals that require data fetching from a backend (e.g., entity field access). +/// This map only stores costs for locals marked as traversals; insertions for non-traversal +/// locals are ignored. This allows the execution planner to focus on the operations that actually +/// require backend coordination. +pub struct TraversalCostVec { + traversals: DenseBitSet, + costs: LocalVec, A>, +} + +impl TraversalCostVec { + /// Creates an empty traversal cost map for the given body. + /// + /// Only locals that are enabled traversals (per [`Traversals::enabled`]) will accept cost + /// insertions; other locals are silently ignored. + pub fn new<'heap>(body: &Body<'heap>, traversals: &Traversals<'heap>, alloc: A) -> Self { + Self { + traversals: traversals.enabled(body), + costs: LocalVec::new_in(alloc), + } + } + + /// Records a cost for a traversal local. + /// + /// If `local` is not a traversal, the insertion is silently ignored. 
+ pub fn insert(&mut self, local: Local, cost: Cost) { + if self.traversals.contains(local) { + self.costs.insert(local, cost); + } + } + + pub fn iter(&self) -> impl Iterator { + self.costs + .iter_enumerated() + .filter_map(|(local, cost)| cost.map(|cost| (local, cost))) + } +} + +impl IntoIterator for &TraversalCostVec { + type Item = (Local, Cost); + + type IntoIter = impl Iterator; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// Dense cost map for all statements in a body. +/// +/// Stores the execution cost for every statement, indexed by [`Location`]. A `None` cost +/// indicates the target cannot execute that statement. The execution planner compares costs +/// across targets to determine the optimal execution strategy. +/// +/// Internally uses a flattened representation with per-block offsets for efficient indexing. +pub struct StatementCostVec { + offsets: Box, A>, + costs: Vec, A>, +} + +impl StatementCostVec { + #[expect(unsafe_code)] + fn from_iter(mut iter: impl ExactSizeIterator, alloc: A) -> Self + where + A: Clone, + { + let mut offsets = Box::new_uninit_slice_in(iter.len() + 1, alloc.clone()); + + let mut offset = 0_u32; + + offsets[0].write(0); + + let (_, rest) = offsets[1..].write_iter(iter::from_fn(|| { + let next = iter.next()?; + + offset += next; + + Some(offset) + })); + + debug_assert!(rest.is_empty()); + debug_assert_eq!(iter.len(), 0); + + let costs = alloc::vec::from_elem_in(None, offset as usize, alloc); + + // SAFETY: We have initialized all elements of the slice. + let offsets = unsafe { offsets.assume_init() }; + let offsets = BasicBlockSlice::from_boxed_slice(offsets); + + Self { offsets, costs } + } + + /// Creates a cost map with space for all statements in the given blocks. + /// + /// All costs are initialized to `None` (unsupported). Use indexing to assign costs. 
+ #[expect(clippy::cast_possible_truncation)] + pub fn new(blocks: &BasicBlocks, alloc: A) -> Self + where + A: Clone, + { + Self::from_iter( + blocks.iter().map(|block| block.statements.len() as u32), + alloc, + ) + } + + pub fn is_empty(&self) -> bool { + self.costs.iter().all(Option::is_none) + } + + /// Returns the cost at `location`, or `None` if out of bounds or unassigned. + pub fn get(&self, location: Location) -> Option { + let range = (self.offsets[location.block] as usize) + ..(self.offsets[location.block.plus(1)] as usize); + + // statement_index is 1-based; 0 would underflow, so treat it as out of bounds + self.costs[range] + .get(location.statement_index.checked_sub(1)?) + .copied() + .flatten() + } +} + +impl Index for StatementCostVec { + type Output = Option; + + fn index(&self, index: Location) -> &Self::Output { + let range = + (self.offsets[index.block] as usize)..(self.offsets[index.block.plus(1)] as usize); + + // statement_index is 1-based + &self.costs[range][index.statement_index - 1] + } +} + +impl IndexMut for StatementCostVec { + fn index_mut(&mut self, index: Location) -> &mut Self::Output { + let range = + (self.offsets[index.block] as usize)..(self.offsets[index.block.plus(1)] as usize); + + // statement_index is 1-based + &mut self.costs[range][index.statement_index - 1] + } +} + +#[cfg(test)] +mod tests { + use alloc::alloc::Global; + + use super::{Cost, StatementCostVec}; + use crate::body::{basic_block::BasicBlockId, location::Location}; + + /// `Cost::new` succeeds for valid values (0 and 100). + #[test] + fn cost_new_valid_values() { + let zero = Cost::new(0); + assert!(zero.is_some()); + + let hundred = Cost::new(100); + assert!(hundred.is_some()); + } + + /// `Cost::new(u32::MAX)` returns `None` (reserved as niche for `Option`). + #[test] + fn cost_new_max_returns_none() { + let max = Cost::new(u32::MAX); + assert!(max.is_none()); + } + + /// `Cost::new(u32::MAX - 1)` succeeds (largest valid cost value). 
+ #[test] + fn cost_new_max_minus_one_valid() { + let max_valid = Cost::new(u32::MAX - 1); + assert!(max_valid.is_some()); + } + + /// `Cost::new_unchecked` with valid values works correctly. + /// + /// This test exercises unsafe code and should be run under Miri. + #[test] + #[expect(unsafe_code)] + fn cost_new_unchecked_valid() { + // SAFETY: 0 is not u32::MAX + let zero = unsafe { Cost::new_unchecked(0) }; + assert_eq!(Cost::new(0), Some(zero)); + + // SAFETY: 100 is not u32::MAX + let hundred = unsafe { Cost::new_unchecked(100) }; + assert_eq!(Cost::new(100), Some(hundred)); + } + + /// `StatementCostVec` correctly indexes by `Location` across multiple blocks. + #[test] + fn statement_cost_vec_indexing() { + // bb0: 2 statements, bb1: 3 statements, bb2: 1 statement + let mut costs = StatementCostVec::from_iter([2, 3, 1].into_iter(), Global); + + // Assign costs at various locations + let loc_0_1 = Location { + block: BasicBlockId::new(0), + statement_index: 1, + }; + let loc_0_2 = Location { + block: BasicBlockId::new(0), + statement_index: 2, + }; + let loc_1_2 = Location { + block: BasicBlockId::new(1), + statement_index: 2, + }; + let loc_2_1 = Location { + block: BasicBlockId::new(2), + statement_index: 1, + }; + + costs[loc_0_1] = Some(cost!(10)); + costs[loc_0_2] = Some(cost!(20)); + costs[loc_1_2] = Some(cost!(30)); + costs[loc_2_1] = Some(cost!(40)); + + // Verify retrieval + assert_eq!(costs.get(loc_0_1), Some(cost!(10))); + assert_eq!(costs.get(loc_0_2), Some(cost!(20))); + assert_eq!(costs.get(loc_1_2), Some(cost!(30))); + assert_eq!(costs.get(loc_2_1), Some(cost!(40))); + + // Unassigned locations return None + let loc_1_1 = Location { + block: BasicBlockId::new(1), + statement_index: 1, + }; + assert_eq!(costs.get(loc_1_1), None); + } + + /// `StatementCostVec` initialization with a single block. + /// + /// This test exercises unsafe code and should be run under Miri. 
+ #[test] + fn statement_cost_vec_init_single_block() { + // Single block with 5 statements + let mut costs = StatementCostVec::from_iter([5].into_iter(), Global); + + // All 5 statements should be accessible + for index in 1..=5_u32 { + let location = Location { + block: BasicBlockId::new(0), + statement_index: index as usize, + }; + + costs[location] = Some(Cost::new(index).expect("should be non-zero")); + } + + for index in 1..=5 { + let location = Location { + block: BasicBlockId::new(0), + statement_index: index as usize, + }; + + assert_eq!(costs.get(location), Cost::new(index)); + } + } + + /// `StatementCostVec` initialization with multiple blocks of varying sizes. + /// + /// This test exercises unsafe code and should be run under Miri. + #[test] + fn statement_cost_vec_init_multiple_blocks() { + // 0 statements, 1 statement, 5 statements + let mut costs = StatementCostVec::from_iter([0, 1, 5].into_iter(), Global); + + // bb1 has 1 statement + let loc_1_1 = Location { + block: BasicBlockId::new(1), + statement_index: 1, + }; + costs[loc_1_1] = Some(cost!(100)); + assert_eq!(costs.get(loc_1_1), Some(cost!(100))); + + // bb2 has 5 statements + for index in 1..=5 { + let location = Location { + block: BasicBlockId::new(2), + statement_index: index as usize, + }; + + costs[location] = Some(Cost::new(index).expect("non-zero")); + } + for index in 1..=5 { + let location = Location { + block: BasicBlockId::new(2), + statement_index: index as usize, + }; + assert_eq!(costs.get(location), Cost::new(index)); + } + } + + /// `StatementCostVec` initialization with zero blocks. + /// + /// This test exercises unsafe code and should be run under Miri. 
+ #[test] + fn statement_cost_vec_init_empty() { + // Should not panic + let _costs = StatementCostVec::from_iter(core::iter::empty::(), Global); + } +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/mod.rs new file mode 100644 index 00000000000..71f9f4db7a6 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/mod.rs @@ -0,0 +1,11 @@ +macro_rules! cost { + ($value:expr) => { + const { $crate::pass::analysis::execution::cost::Cost::new_panic($value) } + }; +} + +mod cost; +pub mod statement_placement; +pub mod target; + +pub use self::cost::{Cost, StatementCostVec, TraversalCostVec}; diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/common.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/common.rs new file mode 100644 index 00000000000..abf70d07c23 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/common.rs @@ -0,0 +1,231 @@ +use core::{alloc::Allocator, cell::Cell, cmp::Reverse}; + +use hashql_core::{ + heap::Heap, + id::bit_vec::{BitRelations as _, DenseBitSet}, +}; + +use crate::{ + body::{ + Body, + basic_block::BasicBlockId, + local::Local, + location::Location, + operand::Operand, + rvalue::RValue, + statement::{Assign, Statement, StatementKind}, + terminator::TerminatorKind, + }, + context::MirContext, + pass::analysis::{ + dataflow::{ + framework::{DataflowAnalysis, DataflowResults}, + lattice::PowersetLattice, + }, + execution::{Cost, StatementCostVec, cost::TraversalCostVec}, + }, + visit::Visitor, +}; + +/// Single-use value wrapper ensuring a value is consumed exactly once. 
+pub(crate) struct OnceValue(Cell>); + +impl OnceValue { + pub(crate) const fn new(value: T) -> Self { + Self(Cell::new(Some(value))) + } + + fn take(&self) -> T { + self.0.take().expect("OnceValue already taken") + } +} + +type RValueFn<'heap> = + fn(&MirContext<'_, 'heap>, &Body<'heap>, &DenseBitSet, &RValue<'heap>) -> bool; + +type OperandFn<'heap> = + fn(&MirContext<'_, 'heap>, &Body<'heap>, &DenseBitSet, &Operand<'heap>) -> bool; + +/// Computes which locals can be dispatched to an execution target. +/// +/// This is a "must" analysis: a local is only considered dispatchable if it is supported along +/// *all* paths reaching return blocks. If any path produces an unsupported value for a local, +/// that local is excluded from the dispatchable set. +/// +/// The analysis is parameterized by target-specific predicates that determine whether individual +/// rvalues and operands are supported by that target. +/// +/// Values flowing through [`GraphRead`] edges are always marked as unsupported, since graph +/// reads must be executed by the interpreter and cannot be dispatched to external backends. +/// +/// [`GraphRead`]: crate::body::terminator::GraphRead +pub(crate) struct SupportedAnalysis<'ctx, 'env, 'heap, B> { + pub body: &'ctx Body<'heap>, + pub context: &'ctx MirContext<'env, 'heap>, + + pub is_supported_rvalue: RValueFn<'heap>, + pub is_supported_operand: OperandFn<'heap>, + pub initialize_boundary: OnceValue, +} + +impl<'heap, B> SupportedAnalysis<'_, '_, 'heap, B> { + /// Runs the analysis and returns the set of dispatchable locals. + /// + /// A local is dispatchable only if it is supported at every return block. + pub(crate) fn finish_in(self, alloc: A) -> DenseBitSet + where + B: FnOnce(&Body<'heap>, &mut DenseBitSet), + { + let body = self.body; + let DataflowResults { exit_states, .. 
} = self.iterate_to_fixpoint_in(body, alloc); + + let mut has_return = false; + let mut dispatchable = DenseBitSet::new_filled(body.local_decls.len()); + + for (bb, state) in exit_states.iter_enumerated() { + if matches!( + body.basic_blocks[bb].terminator.kind, + TerminatorKind::Return(_) + ) { + dispatchable.intersect(state); + has_return = true; + } + } + + if !has_return { + dispatchable.clear(); + } + + dispatchable + } +} + +impl<'heap, B> DataflowAnalysis<'heap> for SupportedAnalysis<'_, '_, 'heap, B> +where + B: FnOnce(&Body<'heap>, &mut DenseBitSet), +{ + type Domain = DenseBitSet; + type Lattice = Reverse; + type Metadata = !; + type SwitchIntData = !; + + fn lattice_in(&self, body: &Body<'heap>, _: A) -> Self::Lattice { + Reverse(PowersetLattice::new(body.local_decls.len())) + } + + fn initialize_boundary( + &self, + body: &Body<'heap>, + domain: &mut Self::Domain, + _: A, + ) { + let initialize_boundary = self.initialize_boundary.take(); + + (initialize_boundary)(body, domain); + } + + fn transfer_statement( + &self, + _: Location, + statement: &Statement<'heap>, + state: &mut Self::Domain, + ) { + let StatementKind::Assign(Assign { lhs, rhs }) = &statement.kind else { + return; + }; + + assert!( + lhs.projections.is_empty(), + "MIR must be in MIR(SSA) form for analysis to take place" + ); + + let is_supported = (self.is_supported_rvalue)(self.context, self.body, state, rhs); + if is_supported { + state.insert(lhs.local); + } else { + state.remove(lhs.local); + } + } + + fn transfer_edge( + &self, + _: BasicBlockId, + source_args: &[Operand<'heap>], + + _: BasicBlockId, + target_params: &[Local], + + state: &mut Self::Domain, + ) { + debug_assert_eq!(source_args.len(), target_params.len()); + + for (arg, &param) in source_args.iter().zip(target_params) { + let is_supported = (self.is_supported_operand)(self.context, self.body, state, arg); + state.set(param, is_supported); + } + } + + fn transfer_graph_read_edge( + &self, + _: BasicBlockId, + + _: 
BasicBlockId, + target_params: &[Local], + + state: &mut Self::Domain, + ) { + // Graph reads must happen inside of the interpreter, and are therefore not supported on any + // backend. + for &param in target_params { + state.remove(param); + } + } +} + +/// Assigns costs to statements based on the dispatchable set. +/// +/// After the supportedness analysis computes which locals are dispatchable, this visitor walks +/// the body and assigns costs. A statement receives a cost if its rvalue is supported given the +/// dispatchable locals; otherwise it gets `None`. Storage statements always receive zero cost. +pub(crate) struct CostVisitor<'ctx, 'env, 'heap> { + pub body: &'ctx Body<'heap>, + pub context: &'ctx MirContext<'env, 'heap>, + pub dispatchable: &'ctx DenseBitSet, + pub cost: Cost, + + pub statement_costs: StatementCostVec<&'heap Heap>, + pub traversal_costs: TraversalCostVec<&'heap Heap>, + + pub is_supported_rvalue: RValueFn<'heap>, +} + +impl<'heap> Visitor<'heap> for CostVisitor<'_, '_, 'heap> { + type Result = Result<(), !>; + + fn visit_statement( + &mut self, + location: Location, + statement: &Statement<'heap>, + ) -> Self::Result { + match &statement.kind { + StatementKind::Assign(Assign { lhs, rhs }) => { + let cost = + (self.is_supported_rvalue)(self.context, self.body, self.dispatchable, rhs) + .then_some(self.cost); + + if let Some(cost) = cost + && lhs.projections.is_empty() + { + self.traversal_costs.insert(lhs.local, cost); + } + + self.statement_costs[location] = cost; + } + StatementKind::StorageDead(_) | StatementKind::StorageLive(_) | StatementKind::Nop => { + self.statement_costs[location] = Some(cost!(0)); + } + } + + Ok(()) + } +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/embedding/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/embedding/mod.rs new file mode 100644 index 00000000000..f565d72713d --- /dev/null +++ 
b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/embedding/mod.rs @@ -0,0 +1,165 @@ +use core::alloc::Allocator; + +use hashql_core::{ + heap::Heap, + id::{Id as _, bit_vec::DenseBitSet}, + symbol::sym, +}; + +use super::{ + StatementPlacement, + common::{CostVisitor, OnceValue, SupportedAnalysis}, +}; +use crate::{ + body::{Body, Source, local::Local, operand::Operand, place::Place, rvalue::RValue}, + context::MirContext, + pass::{ + analysis::execution::{ + Cost, StatementCostVec, + cost::TraversalCostVec, + statement_placement::lookup::{Access, entity_projection_access}, + target::Embedding, + }, + transform::Traversals, + }, + visit::Visitor as _, +}; + +#[cfg(test)] +mod tests; + +fn is_supported_place<'heap>( + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet, + place: &Place<'heap>, +) -> bool { + // For GraphReadFilter bodies, local 1 is the filter argument (vertex). Check if the + // projection path maps to an Embedding-accessible field. 
+ if matches!(body.source, Source::GraphReadFilter(_)) && place.local.as_usize() == 1 { + let local_type = body.local_decls[place.local].r#type; + let type_name = context + .env + .r#type(local_type) + .kind + .opaque() + .map_or_else(|| unreachable!(), |opaque| opaque.name); + + if type_name == sym::path::Entity { + return matches!( + entity_projection_access(&place.projections), + Some(Access::Embedding(_)) + ); + } + + unimplemented!("unimplemented lookup for declared type") + } + + domain.contains(place.local) +} + +fn is_supported_operand<'heap>( + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet, + operand: &Operand<'heap>, +) -> bool { + match operand { + Operand::Place(place) => is_supported_place(context, body, domain, place), + Operand::Constant(_) => false, + } +} + +fn is_supported_rvalue<'heap>( + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet, + rvalue: &RValue<'heap>, +) -> bool { + match rvalue { + RValue::Load(operand) => is_supported_operand(context, body, domain, operand), + RValue::Input(_) + | RValue::Aggregate(_) + | RValue::Binary(_) + | RValue::Unary(_) + | RValue::Apply(_) => false, + } +} + +/// Statement placement for the [`Embedding`] execution target. +/// +/// Only supports loading from entity projections that access the `encodings.vectors` path. +/// No arguments are transferable, and no other operations are supported. 
+pub struct EmbeddingStatementPlacement { + statement_cost: Cost, +} +impl Default for EmbeddingStatementPlacement { + fn default() -> Self { + Self { + statement_cost: cost!(4), + } + } +} + +impl<'heap, A: Allocator + Clone> StatementPlacement<'heap, A> for EmbeddingStatementPlacement { + type Target = Embedding; + + fn statement_placement( + &mut self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + traversals: &Traversals<'heap>, + alloc: A, + ) -> (TraversalCostVec<&'heap Heap>, StatementCostVec<&'heap Heap>) { + let statement_costs = StatementCostVec::new(&body.basic_blocks, context.heap); + let traversal_costs = TraversalCostVec::new(body, traversals, context.heap); + + match body.source { + Source::GraphReadFilter(_) => {} + Source::Ctor(_) | Source::Closure(..) | Source::Thunk(..) | Source::Intrinsic(_) => { + return (traversal_costs, statement_costs); + } + } + + let dispatchable = SupportedAnalysis { + body, + context, + is_supported_rvalue, + is_supported_operand, + initialize_boundary: OnceValue::new( + |body: &Body<'heap>, domain: &mut DenseBitSet| { + match body.source { + Source::GraphReadFilter(_) => {} + Source::Ctor(_) + | Source::Closure(..) + | Source::Thunk(..) 
+ | Source::Intrinsic(_) => return, + } + + debug_assert_eq!(body.args, 2); + + // Embedding backend cannot receive any arguments directly + for arg in 0..body.args { + domain.remove(Local::new(arg)); + } + }, + ), + } + .finish_in(alloc); + + let mut visitor = CostVisitor { + body, + context, + dispatchable: &dispatchable, + cost: self.statement_cost, + + statement_costs, + traversal_costs, + + is_supported_rvalue, + }; + visitor.visit_body(body); + + (visitor.traversal_costs, visitor.statement_costs) + } +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/embedding/tests.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/embedding/tests.rs new file mode 100644 index 00000000000..1d113b7b33f --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/embedding/tests.rs @@ -0,0 +1,239 @@ +//! Tests for [`EmbeddingStatementPlacement`]. +#![expect(clippy::min_ident_chars)] + +use hashql_core::{heap::Heap, symbol::sym, r#type::environment::Environment}; +use hashql_diagnostics::DiagnosticIssues; + +use crate::{ + builder::body, + context::MirContext, + def::DefId, + intern::Interner, + pass::analysis::execution::statement_placement::{ + EmbeddingStatementPlacement, + tests::{assert_placement, run_placement}, + }, +}; + +/// Only `entity.encodings.vectors` projection is supported. +/// +/// Tests that the embedding backend only supports loading from the `encodings.vectors` +/// path on entities. This is the only field stored in the embedding database. +#[test] +fn only_vectors_projection_supported() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> ? 
{ + decl env: (), vertex: [Opaque sym::path::Entity; ?], vectors: ?; + @proj encodings = vertex.encodings: ?, vectors_proj = encodings.vectors: ?; + + bb0() { + vectors = load vectors_proj; + return vectors; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = EmbeddingStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "only_vectors_projection_supported", + "embedding", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} + +/// Both env (local 0) and entity (local 1) are excluded from the dispatchable set. +/// +/// Tests that `initialize_boundary` removes both argument locals from the domain. +/// The embedding backend cannot receive any arguments directly - it only accesses +/// entity fields through projections. +#[test] +fn all_args_excluded() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool { + decl env: (Int), vertex: [Opaque sym::path::Entity; ?], env_val: Int, result: Bool; + @proj env_0 = env.0: Int; + + bb0() { + env_val = load env_0; + result = bin.== env_val 42; + return result; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = EmbeddingStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "all_args_excluded", + "embedding", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} + +/// Entity projections other than `encodings.vectors` are rejected. 
+/// +/// Tests that accessing entity fields like `metadata.archived` or `properties` +/// returns no cost for embedding - these paths map to Postgres, not Embedding. +#[test] +fn non_vectors_entity_projection_rejected() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool { + decl env: (), vertex: [Opaque sym::path::Entity; ?], archived: Bool; + @proj metadata = vertex.metadata: ?, archived_proj = metadata.archived: Bool; + + bb0() { + archived = load archived_proj; + return archived; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = EmbeddingStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "non_vectors_entity_projection_rejected", + "embedding", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} + +/// `StorageLive`/`StorageDead` statements get `cost!(0)`. +/// +/// Tests that storage management statements have zero cost even for Embedding, +/// matching the interpreter behavior. +#[test] +fn storage_statements_zero_cost() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> ? 
{ + decl env: (), vertex: [Opaque sym::path::Entity; ?], vectors: ?; + @proj encodings = vertex.encodings: ?, vectors_proj = encodings.vectors: ?; + + bb0() { + let (vectors.local); + vectors = load vectors_proj; + drop (vectors.local); + return vectors; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = EmbeddingStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "storage_statements_zero_cost", + "embedding", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} + +/// All operations except `Load` of vectors projection are rejected. +/// +/// Tests that Binary, Unary, Aggregate, Apply, Input, and constants all return no cost. +/// The embedding backend is extremely limited - it can only load vector data. +#[test] +fn other_operations_rejected() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let def_id = DefId::new(123); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool { + decl env: (Int), vertex: [Opaque sym::path::Entity; ?], + x: Int, y: Int, sum: Int, neg: Int, + tup: (Int, Int), param: Int, + capture: (Int), func: [fn(Int) -> Int], call_result: Int, + result: Bool; + + bb0() { + x = load 10; + y = load 20; + sum = bin.+ x y; + neg = un.neg sum; + tup = tuple 1, 2; + param = input.load! 
"param"; + capture = load env; + func = closure def_id capture; + call_result = apply func, 1; + result = load true; + return result; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = EmbeddingStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "other_operations_rejected", + "embedding", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/interpret/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/interpret/mod.rs new file mode 100644 index 00000000000..1640c1b9360 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/interpret/mod.rs @@ -0,0 +1,91 @@ +use core::alloc::Allocator; + +use hashql_core::heap::Heap; + +use super::StatementPlacement; +use crate::{ + body::{ + Body, + location::Location, + statement::{Statement, StatementKind}, + }, + context::MirContext, + pass::{ + analysis::execution::{ + cost::{Cost, StatementCostVec, TraversalCostVec}, + target::Interpreter, + }, + transform::Traversals, + }, + visit::Visitor, +}; + +#[cfg(test)] +mod tests; + +struct CostVisitor<'heap> { + cost: Cost, + + statement_costs: StatementCostVec<&'heap Heap>, +} + +impl<'heap> Visitor<'heap> for CostVisitor<'heap> { + type Result = Result<(), !>; + + fn visit_statement( + &mut self, + location: Location, + statement: &Statement<'heap>, + ) -> Self::Result { + // All statements are supported; TraversalExtraction provides backend data access + match &statement.kind { + StatementKind::Assign(_) => { + self.statement_costs[location] = Some(self.cost); + } + StatementKind::StorageDead(_) | StatementKind::StorageLive(_) | StatementKind::Nop => { + self.statement_costs[location] = Some(cost!(0)); + 
} + } + + Ok(()) + } +} + +/// Statement placement for the [`Interpreter`] execution target. +/// +/// Supports all statements unconditionally, serving as the universal fallback. +#[derive(Debug)] +pub struct InterpreterStatementPlacement { + statement_cost: Cost, +} + +impl Default for InterpreterStatementPlacement { + fn default() -> Self { + Self { + statement_cost: cost!(8), + } + } +} + +impl<'heap, A: Allocator> StatementPlacement<'heap, A> for InterpreterStatementPlacement { + type Target = Interpreter; + + fn statement_placement( + &mut self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + traversals: &Traversals<'heap>, + _: A, + ) -> (TraversalCostVec<&'heap Heap>, StatementCostVec<&'heap Heap>) { + let statement_costs = StatementCostVec::new(&body.basic_blocks, context.heap); + let traversal_costs = TraversalCostVec::new(body, traversals, context.heap); + + let mut visitor = CostVisitor { + cost: self.statement_cost, + statement_costs, + }; + visitor.visit_body(body); + + (traversal_costs, visitor.statement_costs) + } +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/interpret/tests.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/interpret/tests.rs new file mode 100644 index 00000000000..6cf1374802b --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/interpret/tests.rs @@ -0,0 +1,119 @@ +//! Tests for [`InterpreterStatementPlacement`]. +#![expect(clippy::min_ident_chars)] + +use hashql_core::{heap::Heap, symbol::sym, r#type::environment::Environment}; +use hashql_diagnostics::DiagnosticIssues; + +use crate::{ + builder::body, + context::MirContext, + def::DefId, + intern::Interner, + pass::analysis::execution::statement_placement::{ + InterpreterStatementPlacement, + tests::{assert_placement, run_placement}, + }, +}; + +/// All statement kinds receive costs (universal fallback). 
+/// +/// Tests that the interpreter supports all `RValue` kinds: Load, Binary, Unary, +/// Aggregate (tuple/struct/closure), Apply, and Input. Every assignment gets +/// a cost because the interpreter is the universal fallback target. +#[test] +fn all_statements_supported() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let def_id = DefId::new(42); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool { + decl env: (Int), vertex: [Opaque sym::path::Entity; ?], + x: Int, y: Int, sum: Int, neg: Int, + tup: (Int, Int), s: (a: Int, b: Int), + capture: (Int), func: [fn(Int) -> Int], call_result: Int, + param: Int, result: Bool; + + bb0() { + x = load 10; + y = load 20; + sum = bin.+ x y; + neg = un.neg sum; + tup = tuple 1, 2; + s = struct a: 3, b: 4; + capture = load env; + func = closure def_id capture; + call_result = apply func, 5; + param = input.load! "param"; + result = load true; + return result; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = InterpreterStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "all_statements_supported", + "interpret", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} + +/// `StorageLive`/`StorageDead`/`Nop` get `cost!(0)`, assignments get `cost!(8)`. +/// +/// Tests the cost differentiation: storage management statements have zero cost +/// because they don't perform computation, while assignments have cost 8. 
+#[test] +fn storage_statements_zero_cost() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Int { + decl env: (), vertex: [Opaque sym::path::Entity; ?], x: Int, y: Int, result: Int; + + bb0() { + let (x.local); + x = load 10; + let (y.local); + y = load 20; + result = bin.+ x y; + drop (x.local); + drop (y.local); + return result; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = InterpreterStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "storage_statements_zero_cost", + "interpret", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/entity.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/entity.rs new file mode 100644 index 00000000000..75b1248858a --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/entity.rs @@ -0,0 +1,172 @@ +use hashql_core::symbol::sym; + +use super::trie::{Access, AccessMode, PathNode}; + +/// Entity path access trie mapping field paths to backend access types. +/// +/// The trie structure mirrors the entity schema, with paths mapping to their storage location: +/// +/// - `properties` → JSONB column in `entity_editions` +/// - `encodings.vectors` → Embedding backend +/// - `metadata.*` → Various columns in `entity_temporal_metadata`, `entity_editions`, etc. +/// - `link_data.*` → `entity_edge` table via joins +/// +/// Entry point is the `entity_temporal_metadata` table which joins to `entity_ids`, +/// `entity_editions`, `entity_is_of_type`, and `entity_edge`. 
+// The static ref here is required, so that the symbols are not duplicated across crates and have +// the same interned string. +pub(super) static ENTITY_PATHS: PathNode = PathNode::root(&[ + // entity_editions.properties (JSONB) + PathNode::jsonb(&sym::lexical::properties), + // (tbd) encodings + PathNode::branch( + &sym::lexical::encodings, + None, + &[ + // Vectors are stored outside the entity inside of an embeddings database + PathNode::branch( + &sym::lexical::vectors, + Access::Embedding(AccessMode::Direct), + &[], + ), + ], + ), + PathNode::branch( + &sym::lexical::metadata, + None, + &[ + // entity_temporal_metadata: web_id, entity_uuid, draft_id, entity_edition_id + PathNode::branch( + &sym::lexical::record_id, + Access::Postgres(AccessMode::Composite), + &[ + // entity_temporal_metadata: web_id, entity_uuid, draft_id + PathNode::branch( + &sym::lexical::entity_id, + Access::Postgres(AccessMode::Composite), + &[ + // entity_temporal_metadata.web_id + PathNode::leaf( + &sym::lexical::web_id, + Access::Postgres(AccessMode::Direct), + ), + // entity_temporal_metadata.entity_uuid + PathNode::leaf( + &sym::lexical::entity_uuid, + Access::Postgres(AccessMode::Direct), + ), + // entity_temporal_metadata.draft_id + PathNode::leaf( + &sym::lexical::draft_id, + Access::Postgres(AccessMode::Direct), + ), + ], + ), + // entity_temporal_metadata.entity_edition_id + PathNode::leaf( + &sym::lexical::edition_id, + Access::Postgres(AccessMode::Direct), + ), + ], + ), + // entity_temporal_metadata: decision_time, transaction_time + PathNode::branch( + &sym::lexical::temporal_versioning, + Access::Postgres(AccessMode::Composite), + &[ + // entity_temporal_metadata.decision_time + PathNode::leaf( + &sym::lexical::decision_time, + Access::Postgres(AccessMode::Direct), + ), + // entity_temporal_metadata.transaction_time + PathNode::leaf( + &sym::lexical::transaction_time, + Access::Postgres(AccessMode::Direct), + ), + ], + ), + // entity_is_of_type (via JOIN) + 
PathNode::leaf( + &sym::lexical::entity_type_ids, + Access::Postgres(AccessMode::Direct), + ), + // entity_editions.archived + PathNode::leaf( + &sym::lexical::archived, + Access::Postgres(AccessMode::Direct), + ), + // entity_editions.confidence + PathNode::leaf( + &sym::lexical::confidence, + Access::Postgres(AccessMode::Direct), + ), + // spans entity_ids.provenance + entity_editions.provenance + PathNode::branch( + &sym::lexical::provenance, + None, + &[ + // entity_ids.provenance (JSONB) + PathNode::jsonb(&sym::lexical::inferred), + // entity_editions.provenance (JSONB) + PathNode::jsonb(&sym::lexical::edition), + ], + ), + // entity_editions.property_metadata (JSONB) + PathNode::jsonb(&sym::lexical::properties), + ], + ), + // contains synthesized draft_id fields + PathNode::branch( + &sym::lexical::link_data, + None, + &[ + // draft_id is synthesized (always None), not stored + PathNode::branch( + &sym::lexical::left_entity_id, + None, + &[ + // entity_has_left_entity -> entity_edge.target_web_id + PathNode::leaf(&sym::lexical::web_id, Access::Postgres(AccessMode::Direct)), + // entity_has_left_entity -> entity_edge.target_entity_uuid + PathNode::leaf( + &sym::lexical::entity_uuid, + Access::Postgres(AccessMode::Direct), + ), + // synthesized, not in entity_edge + PathNode::leaf(&sym::lexical::draft_id, None), + ], + ), + // draft_id is synthesized (always None), not stored + PathNode::branch( + &sym::lexical::right_entity_id, + None, + &[ + // entity_has_right_entity -> entity_edge.target_web_id + PathNode::leaf(&sym::lexical::web_id, Access::Postgres(AccessMode::Direct)), + // entity_has_right_entity -> entity_edge.target_entity_uuid + PathNode::leaf( + &sym::lexical::entity_uuid, + Access::Postgres(AccessMode::Direct), + ), + // synthesized, not in entity_edge + PathNode::leaf(&sym::lexical::draft_id, None), + ], + ), + // entity_edge.confidence (via entity_has_left_entity) + PathNode::leaf( + &sym::lexical::left_entity_confidence, + 
Access::Postgres(AccessMode::Direct), + ), + // entity_edge.provenance (JSONB, via entity_has_left_entity) + PathNode::jsonb(&sym::lexical::left_entity_provenance), + // entity_edge.confidence (via entity_has_right_entity) + PathNode::leaf( + &sym::lexical::right_entity_confidence, + Access::Postgres(AccessMode::Direct), + ), + // entity_edge.provenance (JSONB, via entity_has_right_entity) + PathNode::jsonb(&sym::lexical::right_entity_provenance), + ], + ), +]); diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/mod.rs new file mode 100644 index 00000000000..e7a24e6b0a7 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/mod.rs @@ -0,0 +1,41 @@ +mod entity; +mod trie; + +#[cfg(test)] +mod tests; + +use self::entity::ENTITY_PATHS; +pub(crate) use self::trie::Access; +use crate::body::place::{Projection, ProjectionKind}; + +/// Determines which backend can access an entity field projection. +/// +/// Walks the projection path through the entity schema trie to determine whether the field is +/// stored in Postgres (as a column or JSONB path) or in the embedding store. Returns `None` if +/// the path doesn't map to any supported backend storage. 
+/// +/// For example: +/// - `entity.properties.foo` → `Some(Access::Postgres(Direct))` (JSONB) +/// - `entity.encodings.vectors` → `Some(Access::Embedding(Direct))` +/// - `entity.metadata.record_id.entity_id.web_id` → `Some(Access::Postgres(Direct))` +pub(crate) fn entity_projection_access(projections: &[Projection<'_>]) -> Option<Access> { + let mut node = &ENTITY_PATHS; + + for projection in projections { + if node.children.is_empty() { + return node.otherwise; + } + + let ProjectionKind::FieldByName(name) = projection.kind else { + return node.otherwise; + }; + + let Some(next_node) = node.lookup(name) else { + return node.otherwise; + }; + + node = next_node; + } + + node.access +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/tests.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/tests.rs new file mode 100644 index 00000000000..173257fda61 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/tests.rs @@ -0,0 +1,113 @@ +//! Unit tests for entity projection path lookup. + +use hashql_core::{symbol::sym, r#type::TypeId}; + +use super::{ + entity_projection_access, + trie::{Access, AccessMode}, +}; +use crate::body::place::{Projection, ProjectionKind}; + +/// Helper to create a `FieldByName` projection. +fn proj(name: impl Into<hashql_core::symbol::Symbol<'static>>) -> Projection<'static> { + Projection { + kind: ProjectionKind::FieldByName(name.into()), + r#type: TypeId::PLACEHOLDER, + } +} + +/// `[.properties]` → `Access::Postgres(Direct)` (JSONB column). +#[test] +fn properties_is_postgres() { + let projections = &[proj(sym::lexical::properties)]; + let access = entity_projection_access(projections); + + assert_eq!(access, Some(Access::Postgres(AccessMode::Direct))); +} + +/// `[.properties.foo.bar]` → Postgres (JSONB otherwise). +/// +/// JSONB nodes have `otherwise` set, so any sub-path is also Postgres-accessible. 
+#[test] +fn properties_subpath_is_postgres() { + let projections = &[ + proj(sym::lexical::properties), + proj(sym::lexical::foo), + proj(sym::lexical::bar), + ]; + let access = entity_projection_access(projections); + + assert_eq!(access, Some(Access::Postgres(AccessMode::Direct))); +} + +/// `[.encodings.vectors]` → `Access::Embedding(Direct)`. +#[test] +fn vectors_is_embedding() { + let projections = &[proj(sym::lexical::encodings), proj(sym::lexical::vectors)]; + let access = entity_projection_access(projections); + + assert_eq!(access, Some(Access::Embedding(AccessMode::Direct))); +} + +/// Various metadata paths map to Postgres columns. +#[test] +fn metadata_columns_are_postgres() { + // metadata.archived -> Direct + let projections = &[proj(sym::lexical::metadata), proj(sym::lexical::archived)]; + assert_eq!( + entity_projection_access(projections), + Some(Access::Postgres(AccessMode::Direct)) + ); + + // metadata.record_id -> Composite + let projections = &[proj(sym::lexical::metadata), proj(sym::lexical::record_id)]; + assert_eq!( + entity_projection_access(projections), + Some(Access::Postgres(AccessMode::Composite)) + ); + + // metadata.record_id.entity_id.web_id -> Direct + let projections = &[ + proj(sym::lexical::metadata), + proj(sym::lexical::record_id), + proj(sym::lexical::entity_id), + proj(sym::lexical::web_id), + ]; + assert_eq!( + entity_projection_access(projections), + Some(Access::Postgres(AccessMode::Direct)) + ); + + // metadata.temporal_versioning.decision_time -> Direct + let projections = &[ + proj(sym::lexical::metadata), + proj(sym::lexical::temporal_versioning), + proj(sym::lexical::decision_time), + ]; + assert_eq!( + entity_projection_access(projections), + Some(Access::Postgres(AccessMode::Direct)) + ); +} + +/// `link_data.left_entity_id.draft_id` → `None` (synthesized, not stored). 
+#[test] +fn link_data_synthesized_is_none() { + let projections = &[ + proj(sym::lexical::link_data), + proj(sym::lexical::left_entity_id), + proj(sym::lexical::draft_id), + ]; + let access = entity_projection_access(projections); + + assert_eq!(access, None); +} + +/// Invalid path like `[.unknown]` → `None`. +#[test] +fn unknown_path_returns_none() { + let projections = &[proj(sym::lexical::unknown)]; + let access = entity_projection_access(projections); + + assert_eq!(access, None); +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/trie.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/trie.rs new file mode 100644 index 00000000000..d9db87d455b --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/lookup/trie.rs @@ -0,0 +1,82 @@ +use hashql_core::symbol::{Symbol, sym}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub(crate) enum AccessMode { + Direct, + Composite, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub(crate) enum Access { + Postgres(AccessMode), + Embedding(AccessMode), +} + +/// A node in the path access trie. +/// +/// Each node represents a field in a path hierarchy and defines: +/// - The field name this node matches (`name`). +/// - What access applies when the path ends at this node (`access`). +/// - What access applies for unknown/deeper paths (`otherwise`). +/// - What children exist for further path traversal. +#[derive(Debug, Copy, Clone)] +pub(crate) struct PathNode { + /// Field name this node matches (the `entity` symbol for the root node). + pub name: &'static Symbol<'static>, + /// Access level when the path ends at this node (no more projections). + pub access: Option<Access>, + /// Access level for paths beyond known children (e.g., JSONB allows any sub-path). + pub otherwise: Option<Access>, + /// Child nodes. 
+ pub children: &'static [Self], +} + +impl PathNode { + pub(crate) const fn root(children: &'static [Self]) -> Self { + Self { + name: &sym::lexical::entity, + access: None, + otherwise: None, + children, + } + } + + pub(crate) const fn leaf( + name: &'static Symbol<'static>, + access: impl [const] Into<Option<Access>>, + ) -> Self { + Self { + name, + access: access.into(), + otherwise: None, + children: &[], + } + } + + /// Creates a JSONB node where any sub-path is also Postgres-accessible. + pub(crate) const fn jsonb(name: &'static Symbol<'static>) -> Self { + Self { + name, + access: Some(Access::Postgres(AccessMode::Direct)), + otherwise: Some(Access::Postgres(AccessMode::Direct)), + children: &[], + } + } + + pub(crate) const fn branch( + name: &'static Symbol<'static>, + access: impl [const] Into<Option<Access>>, + children: &'static [Self], + ) -> Self { + Self { + name, + access: access.into(), + otherwise: None, + children, + } + } + + pub(crate) fn lookup(&self, name: Symbol<'_>) -> Option<&Self> { + self.children.iter().find(|node| *node.name == name) + } +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/mod.rs new file mode 100644 index 00000000000..c7ca3774426 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/mod.rs @@ -0,0 +1,66 @@ +//! Statement placement analysis for MIR execution targets. +//! +//! Determines which MIR statements can be executed on each [`ExecutionTarget`] and assigns costs +//! to supported statements. The execution planner uses these costs to select optimal targets for +//! different parts of a query. +//! +//! Each target has different capabilities: +//! - [`PostgresStatementPlacement`]: Most operations except closures and function calls +//! - [`EmbeddingStatementPlacement`]: Only `encodings.vectors` entity projections +//! 
- [`InterpreterStatementPlacement`]: All operations (universal fallback) + +use core::alloc::Allocator; + +use hashql_core::heap::Heap; + +#[cfg(test)] +mod tests; + +mod common; +mod embedding; +mod interpret; +mod lookup; +mod postgres; + +pub use self::{ + embedding::EmbeddingStatementPlacement, interpret::InterpreterStatementPlacement, + postgres::PostgresStatementPlacement, +}; +use super::target::ExecutionTarget; +use crate::{ + body::Body, + context::MirContext, + pass::{ + analysis::execution::cost::{StatementCostVec, TraversalCostVec}, + transform::Traversals, + }, +}; + +/// Computes statement placement costs for a specific execution target. +/// +/// Implementations analyze a [`Body`] to determine which statements can be dispatched to their +/// associated [`ExecutionTarget`]. Each statement that can be executed on the target receives a +/// cost; statements that cannot be executed have no cost assigned (`None`). +/// +/// The analysis considers: +/// - Whether each rvalue (operation) is supported by the target +/// - Whether operands flow through supported paths to reach return blocks +/// - Special handling for entity field projections based on storage location +pub trait StatementPlacement<'heap, A: Allocator> { + type Target: ExecutionTarget; + + /// Computes placement costs for `body`. + /// + /// Returns two cost vectors: + /// - Traversal costs: For locals that require backend data fetching + /// - Statement costs: For all statements in the body + /// + /// A `None` cost means the target cannot execute that statement/traversal. 
+ fn statement_placement( + &mut self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + traversals: &Traversals<'heap>, + alloc: A, + ) -> (TraversalCostVec<&'heap Heap>, StatementCostVec<&'heap Heap>); +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/postgres/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/postgres/mod.rs new file mode 100644 index 00000000000..5f3482347e9 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/postgres/mod.rs @@ -0,0 +1,246 @@ +use core::{alloc::Allocator, ops::ControlFlow}; + +use hashql_core::{ + heap::Heap, + id::{Id as _, bit_vec::DenseBitSet}, + symbol::sym, + r#type::{ + self, + environment::Environment, + visit::{RecursiveVisitorGuard, Visitor as _}, + }, +}; + +use super::{ + StatementPlacement, + common::{CostVisitor, OnceValue, SupportedAnalysis}, +}; +use crate::{ + body::{ + Body, Source, + constant::Constant, + local::Local, + operand::Operand, + place::Place, + rvalue::{Aggregate, AggregateKind, Binary, RValue, Unary}, + }, + context::MirContext, + pass::{ + analysis::execution::{ + cost::{Cost, StatementCostVec, TraversalCostVec}, + statement_placement::lookup::{Access, entity_projection_access}, + target::Postgres, + }, + transform::Traversals, + }, + visit::Visitor as _, +}; + +#[cfg(test)] +mod tests; + +const fn is_supported_constant(constant: &Constant<'_>) -> bool { + match constant { + Constant::Int(_) | Constant::Primitive(_) | Constant::Unit => true, + Constant::FnPtr(_) => false, + } +} + +fn is_supported_place<'heap>( + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet<Local>, + place: &Place<'heap>, +) -> bool { + // For GraphReadFilter bodies, local 1 is the filter argument (vertex). Check if the + // projection path maps to a Postgres-accessible field. 
+ if matches!(body.source, Source::GraphReadFilter(_)) && place.local.as_usize() == 1 { + let local_type = body.local_decls[place.local].r#type; + let type_name = context + .env + .r#type(local_type) + .kind + .opaque() + .map_or_else(|| unreachable!(), |opaque| opaque.name); + + if type_name == sym::path::Entity { + return matches!( + entity_projection_access(&place.projections), + Some(Access::Postgres(_)) + ); + } + + unimplemented!("unimplemented lookup for declared type") + } + + domain.contains(place.local) +} + +fn is_supported_operand<'heap>( + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet<Local>, + operand: &Operand<'heap>, +) -> bool { + match operand { + Operand::Place(place) => is_supported_place(context, body, domain, place), + Operand::Constant(constant) => is_supported_constant(constant), + } +} + +fn is_supported_rvalue<'heap>( + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet<Local>, + rvalue: &RValue<'heap>, +) -> bool { + match rvalue { + RValue::Load(operand) => is_supported_operand(context, body, domain, operand), + RValue::Binary(Binary { op: _, left, right }) => { + // All MIR binary operations have Postgres equivalents (with type coercion) + is_supported_operand(context, body, domain, left) + && is_supported_operand(context, body, domain, right) + } + RValue::Unary(Unary { op: _, operand }) => { + // All MIR unary operations have Postgres equivalents (with type coercion) + is_supported_operand(context, body, domain, operand) + } + RValue::Aggregate(Aggregate { kind, operands }) => { + if *kind == AggregateKind::Closure { + return false; + } + + // Non-closure aggregates can be constructed as JSONB + operands + .iter() + .all(|operand| is_supported_operand(context, body, domain, operand)) + } + // Query parameters are passed to Postgres + RValue::Input(_) => true, + // Function calls cannot be pushed to Postgres + RValue::Apply(_) => false, + } +} + +struct HasClosureVisitor<'env, 'heap, G = 
RecursiveVisitorGuard<'heap>> { + env: &'env Environment<'heap>, + guard: G, +} + +impl<'heap, G> r#type::visit::Visitor<'heap> for HasClosureVisitor<'_, 'heap, G> +where + G: AsMut<RecursiveVisitorGuard<'heap>>, +{ + type Filter = r#type::visit::filter::Deep; + type Result = ControlFlow<()>; + + fn env(&self) -> &Environment<'heap> { + self.env + } + + fn visit_type(&mut self, r#type: r#type::Type<'heap>) -> Self::Result { + self.guard.as_mut().with( + |guard, r#type| { + r#type::visit::walk_type( + &mut HasClosureVisitor { + env: self.env, + guard, + }, + r#type, + ) + }, + r#type, + ) + } + + fn visit_closure(&mut self, _: r#type::Type<'heap, r#type::kind::ClosureType>) -> Self::Result { + ControlFlow::Break(()) + } +} + +/// Statement placement for the [`Postgres`] execution target. +/// +/// Supports constants, binary/unary operations, aggregates (except closures), inputs, and entity +/// field projections that map to Postgres columns or JSONB paths. The environment argument is +/// only transferable if it contains no closure types. +pub struct PostgresStatementPlacement<'heap> { + statement_cost: Cost, + type_visitor_guard: RecursiveVisitorGuard<'heap>, +} + +impl Default for PostgresStatementPlacement<'_> { + fn default() -> Self { + Self { + statement_cost: cost!(4), + type_visitor_guard: RecursiveVisitorGuard::default(), + } + } +} + +impl<'heap, A: Allocator + Clone> StatementPlacement<'heap, A> + for PostgresStatementPlacement<'heap> +{ + type Target = Postgres; + + fn statement_placement( + &mut self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + traversals: &Traversals<'heap>, + alloc: A, + ) -> (TraversalCostVec<&'heap Heap>, StatementCostVec<&'heap Heap>) { + let traversal_costs = TraversalCostVec::new(body, traversals, context.heap); + let statement_costs = StatementCostVec::new(&body.basic_blocks, context.heap); + + match body.source { + Source::GraphReadFilter(_) => {} + Source::Ctor(_) | Source::Closure(..) | Source::Thunk(..) 
| Source::Intrinsic(_) => { + return (traversal_costs, statement_costs); + } + } + + let dispatchable = SupportedAnalysis { + body, + context, + is_supported_rvalue, + is_supported_operand, + initialize_boundary: OnceValue::new( + |body: &Body<'heap>, domain: &mut DenseBitSet<Local>| { + debug_assert_eq!(body.args, 2); + + // Environment (local 0) is only transferable if it contains no closures + let env_type = body.local_decls[Local::new(0)].r#type; + let has_closure = HasClosureVisitor { + env: context.env, + guard: &mut self.type_visitor_guard, + } + .visit_id(env_type) + .is_break(); + + if has_closure { + domain.remove(Local::new(0)); + } + + // Entity argument (local 1) must be constructed from field projections + domain.remove(Local::new(1)); + }, + ), + } + .finish_in(alloc); + + let mut visitor = CostVisitor { + body, + context, + dispatchable: &dispatchable, + cost: self.statement_cost, + + statement_costs, + traversal_costs, + + is_supported_rvalue, + }; + visitor.visit_body(body); + + (visitor.traversal_costs, visitor.statement_costs) + } +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/postgres/tests.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/postgres/tests.rs new file mode 100644 index 00000000000..adcdcbc9163 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/postgres/tests.rs @@ -0,0 +1,602 @@ +//! Tests for [`PostgresStatementPlacement`]. 
+#![expect(clippy::min_ident_chars, clippy::similar_names)] + +use hashql_core::{ + heap::Heap, + symbol::sym, + r#type::{TypeId, builder::TypeBuilder, environment::Environment}, +}; +use hashql_diagnostics::DiagnosticIssues; + +use crate::{ + body::{ + Source, + operand::Operand, + terminator::{GraphRead, GraphReadHead, GraphReadTail, TerminatorKind}, + }, + builder::{BodyBuilder, body}, + context::MirContext, + def::DefId, + intern::Interner, + op, + pass::{ + analysis::execution::statement_placement::{ + PostgresStatementPlacement, StatementPlacement as _, + tests::{assert_placement, run_placement}, + }, + transform::Traversals, + }, +}; + +/// Arithmetic and comparison operations work. +/// +/// Tests that `Binary` and `Unary` `RValue`s are supported when operands are constants +/// or come from dispatchable locals. Uses only constants to isolate the operator support. +#[test] +fn binary_unary_ops_supported() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool { + decl env: (), vertex: [Opaque sym::path::Entity; ?], x: Int, y: Int, sum: Int, cond: Bool, neg_cond: Bool; + + bb0() { + x = load 10; + y = load 20; + sum = bin.+ x y; + cond = bin.> sum 15; + neg_cond = un.! cond; + return neg_cond; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = PostgresStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "binary_unary_ops_supported", + "postgres", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} + +/// Tuple and struct aggregates work (constructed as JSONB in Postgres). +/// +/// Tests that `Aggregate` `RValue`s with `Tuple` and `Struct` kinds are supported. +/// These can be serialized to JSONB in Postgres. 
+#[test] +fn aggregate_tuple_supported() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool { + decl env: (), vertex: [Opaque sym::path::Entity; ?], tup: (Int, Int), s: (a: Int, b: Int), result: Bool; + + bb0() { + tup = tuple 1, 2; + s = struct a: 10, b: 20; + result = load true; + return result; + } + }); + + let mut context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let mut placement = PostgresStatementPlacement::default(); + let (body, statement_costs, traversal_costs) = + run_placement(&mut context, &mut placement, body); + + assert_placement( + "aggregate_tuple_supported", + "postgres", + &body, + &context, + &statement_costs, + &traversal_costs, + ); +} + +/// `Closure` aggregate returns `None` (closures cannot be pushed to Postgres). +/// +/// Tests that `Aggregate` `RValue`s with `Closure` kind return no cost. +/// Closures contain function pointers which cannot be serialized to Postgres. 
+#[test]
+fn aggregate_closure_rejected() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let def_id = DefId::new(42); // referenced by name inside the `body!` block below
+
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool {
+        decl env: Int, vertex: [Opaque sym::path::Entity; ?], closure_env: Int, closure: [fn(Int) -> Int], result: Bool;
+
+        bb0() {
+            closure_env = load env;
+            closure = closure def_id closure_env; // closure aggregate under test
+            result = load true;
+            return result;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "aggregate_closure_rejected", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// Function calls (`Apply`) are never supported by Postgres.
+///
+/// Tests that `RValue::Apply` always returns no cost, regardless of whether its
+/// operands are supported. Postgres cannot execute arbitrary function calls.
+/// Uses an environment with a simple type (no closure) to ensure the function
+/// operand itself is dispatchable, isolating that Apply is the unsupported part.
+#[test]
+fn apply_rejected() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let def_id = DefId::new(99); // referenced by name inside the `body!` block below
+
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool {
+        decl env: Int, vertex: [Opaque sym::path::Entity; ?], capture: Int, func: [fn(Int) -> Int], result: Int, cond: Bool;
+
+        bb0() {
+            capture = load env;
+            func = closure def_id capture;
+            result = apply func, 42; // `Apply` under test
+            cond = bin.== result 0;
+            return cond;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "apply_rejected", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// `RValue::Input` (query parameters) works.
+///
+/// Tests that `RValue::Input` is supported. Query parameters are passed to Postgres
+/// as bound parameters in prepared statements.
+#[test]
+fn input_supported() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool {
+        decl env: (), vertex: [Opaque sym::path::Entity; ?], param: Int, result: Bool;
+
+        bb0() {
+            param = input.load! "threshold"; // input load under test
+            result = bin.> param 100;
+            return result;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "input_supported", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// Environment argument (local 0) containing closure type is excluded from dispatchable set.
+///
+/// Tests that `HasClosureVisitor` correctly detects closure types nested in the environment
+/// type and removes local 0 from the dispatchable set. Even accessing non-closure fields
+/// of an env that contains closures is rejected.
+#[test]
+fn env_with_closure_type_rejected() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool {
+        decl env: (Int, [fn(Int) -> Int]), vertex: [Opaque sym::path::Entity; ?], val: Int, result: Bool;
+        @proj env_int = env.0: Int; // the non-closure field of the env tuple
+
+        bb0() {
+            val = load env_int;
+            result = bin.== val 42;
+            return result;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "env_with_closure_type_rejected", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// Environment argument with simple types is included in dispatchable set.
+///
+/// Tests that when the environment type contains no closures, local 0 remains in the
+/// dispatchable set and can be accessed. Contrast with `env_with_closure_type_rejected`.
+#[test]
+fn env_without_closure_accepted() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool {
+        decl env: (Int, Bool), vertex: [Opaque sym::path::Entity; ?], val: Int, result: Bool;
+        @proj env_int = env.0: Int; // first field of the closure-free env tuple
+
+        bb0() {
+            val = load env_int;
+            result = bin.== val 42;
+            return result;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "env_without_closure_accepted", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// Entity field projections mapping to Postgres columns are supported.
+///
+/// Tests that projecting `entity.metadata.archived` returns a cost since `archived`
+/// maps to a direct Postgres column in `entity_editions`.
+#[test]
+fn entity_projection_column() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    // Access entity.metadata.archived which maps to a Postgres column
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool {
+        decl env: (), vertex: [Opaque sym::path::Entity; ?];
+        @proj metadata = vertex.metadata: ?, archived = metadata.archived: Bool; // chained: vertex -> metadata -> archived
+
+        bb0() {
+            return archived;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "entity_projection_column", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// Entity field projections mapping to JSONB paths are supported.
+///
+/// Tests that projecting `entity.properties` returns a cost since `properties`
+/// maps to a JSONB column in `entity_editions`.
+#[test]
+fn entity_projection_jsonb() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    // Access entity.properties which maps to JSONB in Postgres
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> ? {
+        decl env: (), vertex: [Opaque sym::path::Entity; ?], props: ?;
+        @proj properties = vertex.properties: ?;
+
+        bb0() {
+            props = load properties; // load through the projection under test
+            return props;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "entity_projection_jsonb", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// `StorageLive`/`StorageDead` statements get `cost!(0)`.
+///
+/// Tests that storage management statements have zero cost even for Postgres,
+/// matching the interpreter behavior.
+#[test]
+fn storage_statements_zero_cost() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Int {
+        decl env: (), vertex: [Opaque sym::path::Entity; ?], x: Int, y: Int, result: Int;
+
+        bb0() {
+            let (x.local);
+            x = load 10;
+            let (y.local);
+            y = load 20;
+            result = bin.+ x y;
+            drop (x.local);
+            drop (y.local);
+            return result;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "storage_statements_zero_cost", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// If one branch has an unsupported op, local is excluded from dispatchable set (must analysis).
+///
+/// Diamond CFG: both branches converge at bb3. One branch (bb1) assigns `x` via a supported
+/// operation (`load 42`), while the other (bb2) assigns `x` via an unsupported operation
+/// (`apply`). The must analysis requires that a local be supported on ALL paths to be in
+/// the dispatchable set. Since bb2's assignment is unsupported, `x` at bb3 cannot be
+/// guaranteed dispatchable.
+///
+/// Uses env without closures so that local 0 is in the dispatchable set, isolating the
+/// must-analysis behavior from closure exclusion.
+#[test]
+fn diamond_must_analysis() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let def_id = DefId::new(77); // referenced by name inside the `body!` block below
+
+    let body = body!(interner, env; [graph::read::filter]@0/2 -> Bool {
+        decl env: (Int), vertex: [Opaque sym::path::Entity; ?], cond: Bool, capture: (Int), func: [fn(Int) -> Int], x: Int, result: Bool;
+
+        bb0() {
+            cond = load true;
+            capture = load env;
+            func = closure def_id capture;
+            if cond then bb1() else bb2();
+        },
+        bb1() {
+            x = load 42; // supported definition of x
+            goto bb3(x);
+        },
+        bb2() {
+            x = apply func, 1; // unsupported definition of x
+            goto bb3(x);
+        },
+        bb3(x) {
+            result = bin.== x 0; // join point: x arrives from both branches
+            return result;
+        }
+    });
+
+    let mut context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (body, statement_costs, traversal_costs) =
+        run_placement(&mut context, &mut placement, body); // traversal extraction, then placement
+
+    assert_placement(
+        "diamond_must_analysis", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
+
+/// Values flowing through `GraphRead` edges become unsupported.
+///
+/// Tests that `transfer_graph_read_edge` correctly marks target block parameters as
+/// non-dispatchable. Graph reads must be executed by the interpreter, so any value
+/// produced by a graph read cannot be pushed to Postgres.
+///
+/// Uses fluent builder API because `GraphRead` terminator is not supported by the `body!` macro.
+#[test]
+fn graph_read_edge_unsupported() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    let int_ty = TypeBuilder::synthetic(&env).integer();
+    let bool_ty = TypeBuilder::synthetic(&env).boolean();
+    let unit_ty = TypeBuilder::synthetic(&env).tuple([] as [TypeId; 0]);
+    let entity_ty = TypeBuilder::synthetic(&env).opaque(sym::path::Entity, bool_ty);
+
+    let mut builder = BodyBuilder::new(&interner);
+
+    let _env_local = builder.local("env", unit_ty); // declared first, ahead of `vertex`
+    let vertex = builder.local("vertex", entity_ty);
+    let axis = builder.local("axis", int_ty);
+    let graph_result = builder.local("graph_result", int_ty);
+    let local_val = builder.local("local_val", int_ty);
+    let sum = builder.local("sum", int_ty);
+    let result = builder.local("result", bool_ty);
+
+    let const_10 = builder.const_int(10);
+    let const_0 = builder.const_int(0);
+
+    let bb0 = builder.reserve_block([]);
+    let bb1 = builder.reserve_block([graph_result.local]); // `graph_result` is handed to bb1 at reservation
+
+    builder
+        .build_block(bb0)
+        .assign_place(axis, |rv| rv.load(const_10))
+        .assign_place(local_val, |rv| rv.load(const_10))
+        .finish_with_terminator(TerminatorKind::GraphRead(GraphRead {
+            head: GraphReadHead::Entity {
+                axis: Operand::Place(axis),
+            },
+            body: Vec::new_in(&heap),
+            tail: GraphReadTail::Collect,
+            target: bb1, // the graph-read edge under test: bb0 -> bb1
+        }));
+
+    builder
+        .build_block(bb1)
+        .assign_place(sum, |rv| rv.binary(graph_result, op![+], local_val))
+        .assign_place(result, |rv| rv.binary(sum, op![>], const_0))
+        .ret(result);
+
+    let mut body = builder.finish(2, bool_ty);
+    body.source = Source::GraphReadFilter(hashql_hir::node::HirId::PLACEHOLDER);
+
+    let context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let traversals = Traversals::with_capacity_in(vertex.local, body.local_decls.len(), &heap); // built by hand; `run_placement` is not used here
+
+    let mut placement = PostgresStatementPlacement::default();
+    let (traversal_costs, statement_costs) =
+        placement.statement_placement(&context, &body, &traversals, &heap); // note: traversal costs come first in the tuple
+
+    assert_placement(
+        "graph_read_edge_unsupported", // snapshot name
+        "postgres", // snapshot sub-directory
+        &body,
+        &context,
+        &statement_costs,
+        &traversal_costs,
+    );
+}
diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs
new file mode 100644
index 00000000000..a2f949c754c
--- /dev/null
+++ b/libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs
@@ -0,0 +1,193 @@
+//! Shared test harness for statement placement analysis.
+#![expect(clippy::min_ident_chars)]
+
+use alloc::alloc::Global;
+use core::{alloc::Allocator, fmt::Display};
+use std::{io::Write as _, path::PathBuf};
+
+use hashql_core::{
+    heap::Heap,
+    pretty::Formatter,
+    r#type::{TypeFormatter, TypeFormatterOptions, environment::Environment},
+};
+use hashql_diagnostics::DiagnosticIssues;
+use insta::{Settings, assert_snapshot};
+
+use super::StatementPlacement;
+use crate::{
+    body::{Body, local::Local, location::Location, statement::Statement},
+    builder::body,
+    context::MirContext,
+    intern::Interner,
+    pass::{
+        Changed, TransformPass as _,
+        analysis::execution::{
+            cost::{StatementCostVec, TraversalCostVec},
+            statement_placement::{EmbeddingStatementPlacement, PostgresStatementPlacement},
+        },
+        transform::{TraversalExtraction, Traversals},
+    },
+    pretty::{TextFormatAnnotations, TextFormatOptions},
+};
+
+/// Annotation provider that displays statement costs as trailing comments.
+struct CostAnnotations<'costs, A: Allocator> {
+    costs: &'costs StatementCostVec<A>,
+}
+
+impl<A: Allocator> TextFormatAnnotations for CostAnnotations<'_, A> {
+    type StatementAnnotation<'this, 'heap>
+        = impl Display
+    where
+        Self: 'this;
+
+    fn annotate_statement<'heap>(
+        &self,
+        location: Location,
+        _statement: &Statement<'heap>,
+    ) -> Option<Self::StatementAnnotation<'_, 'heap>> {
+        let cost = self.costs.get(location)?; // statements without a recorded cost get no annotation
+
+        Some(core::fmt::from_fn(move |fmt| write!(fmt, "cost: {cost}")))
+    }
+}
+
+/// Formats traversal costs as a summary section: a `Traversals:` header, then one line per entry.
+fn format_traversals<A: Allocator>(traversal_costs: &TraversalCostVec<A>) -> impl Display {
+    core::fmt::from_fn(move |f| {
+        writeln!(f, "Traversals:")?;
+        for (local, cost) in traversal_costs {
+            writeln!(f, " {local}: {cost}")?;
+        }
+        Ok(())
+    })
+}
+
+/// Renders `body` annotated with the given costs and asserts the text matches a stored snapshot.
+#[track_caller]
+pub(crate) fn assert_placement<'heap, A: Allocator>(
+    name: &'static str,
+    snapshot_subdir: &str,
+    body: &Body<'heap>,
+    context: &MirContext<'_, 'heap>,
+    statement_costs: &StatementCostVec<A>,
+    traversal_costs: &TraversalCostVec<A>,
+) {
+    let formatter = Formatter::new(context.heap);
+    let type_formatter = TypeFormatter::new(&formatter, context.env, TypeFormatterOptions::terse());
+
+    let annotations = CostAnnotations {
+        costs: statement_costs,
+    };
+
+    let mut text_format = TextFormatOptions {
+        writer: Vec::<u8>::new(),
+        indent: 4,
+        sources: (),
+        types: type_formatter,
+        annotations,
+    }
+    .build();
+
+    text_format.format_body(body).expect("formatting failed");
+
+    write!(
+        text_format.writer,
+        "\n\n{:=^50}\n\n",
+        " Traversals " // `str` honours fill/alignment directly; no intermediate `String` needed
+    )
+    .expect("infallible");
+
+    write!(text_format.writer, "{}", format_traversals(traversal_costs))
+        .expect("formatting failed");
+
+    // Snapshot configuration
+    let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+    let mut settings = Settings::clone_current();
+    settings.set_snapshot_path(dir.join(format!(
+        "tests/ui/pass/execution/statement_placement/{snapshot_subdir}"
+    )));
+    settings.set_prepend_module_to_snapshot(false);
+
+    let _guard = settings.bind_to_scope(); // settings stay bound until the guard drops at function end
+
+    let output = String::from_utf8_lossy(&text_format.writer);
+    assert_snapshot!(name, output);
+}
+
+/// Runs traversal extraction and then `placement` on `body`; panics if no traversals result.
+///
+/// Returns the (possibly rewritten) body together with its statement and traversal cost vectors.
+#[track_caller]
+pub(crate) fn run_placement<'heap>(
+    context: &mut MirContext<'_, 'heap>,
+    placement: &mut impl StatementPlacement<'heap, Global>,
+    mut body: Body<'heap>,
+) -> (
+    Body<'heap>,
+    StatementCostVec<&'heap Heap>,
+    TraversalCostVec<&'heap Heap>,
+) {
+    // Run TraversalExtraction to produce Traversals
+    let mut extraction = TraversalExtraction::new_in(Global);
+    let _: Changed = extraction.run(context, &mut body);
+    let traversals = extraction
+        .take_traversals()
+        .expect("expected GraphReadFilter body");
+
+    // Run placement analysis (yields traversal costs first, statement costs second)
+    let (traversal_costs, statement_costs) =
+        placement.statement_placement(context, &body, &traversals, Global);
+
+    (body, statement_costs, traversal_costs)
+}
+
+// =============================================================================
+// Shared Tests
+// =============================================================================
+
+/// Non-`GraphReadFilter` sources return empty costs for Postgres and Embedding.
+///
+/// Tests that only `Source::GraphReadFilter` bodies produce placement costs.
+/// Other sources (Closure, Thunk, Ctor, Intrinsic) should return empty cost vectors
+/// for specialized backends, though Interpreter still assigns costs.
+#[test]
+fn non_graph_read_filter_returns_empty() {
+    let heap = Heap::new();
+    let interner = Interner::new(&heap);
+    let env = Environment::new(&heap);
+
+    // Use a closure source instead of GraphReadFilter
+    let body = body!(interner, env; fn@0/0 -> Int {
+        decl x: Int, y: Int, result: Int;
+
+        bb0() {
+            x = load 10;
+            y = load 20;
+            result = bin.+ x y;
+            return result;
+        }
+    });
+
+    let context = MirContext {
+        heap: &heap,
+        env: &env,
+        interner: &interner,
+        diagnostics: DiagnosticIssues::new(),
+    };
+
+    let traversals = Traversals::with_capacity_in(Local::new(1), body.local_decls.len(), &heap); // built by hand; extraction is not run
+
+    let mut postgres = PostgresStatementPlacement::default();
+    let mut embedding = EmbeddingStatementPlacement::default();
+
+    let (postgres_traversal, postgres_statement) =
+        postgres.statement_placement(&context, &body, &traversals, &heap);
+    let (embedding_traversal, embedding_statement) =
+        embedding.statement_placement(&context, &body, &traversals, &heap);
+
+    assert_eq!(postgres_traversal.iter().count(), 0);
+    assert!(postgres_statement.is_empty());
+    assert_eq!(embedding_traversal.iter().count(), 0);
+    assert!(embedding_statement.is_empty());
+}
diff --git a/libs/@local/hashql/mir/src/pass/analysis/execution/target.rs b/libs/@local/hashql/mir/src/pass/analysis/execution/target.rs
new file mode 100644
index 00000000000..fb965937f92
--- /dev/null
+++ b/libs/@local/hashql/mir/src/pass/analysis/execution/target.rs
@@ -0,0 +1,96 @@
+use hashql_core::id;
+
+use crate::pass::simplify_type_name;
+
+id::newtype!(
+    pub struct TargetId(u8 is 0..=0xF0)
+);
+
+impl TargetId {
+    pub const EMBEDDING: Self = Self(0x02);
+    pub const INTERPRETER: Self = Self(0x01);
+    pub const POSTGRES: Self = Self(0x00);
+}
+
+/// A backend capable of executing MIR statements.
+///
+/// Each target represents a different execution environment with its own capabilities and
+/// performance characteristics. The execution planner uses statement placement analysis to
+/// determine which statements each target can handle, then selects the optimal target based on
+/// cost.
+///
+/// Currently supported targets:
+/// - [`Postgres`]: Pushes operations into SQL for database-side execution
+/// - [`Embedding`]: Routes vector operations to the embedding store
+/// - [`Interpreter`]: Executes in the HashQL runtime (universal fallback)
+pub trait ExecutionTarget {
+    /// Returns the unique identifier for this target.
+    ///
+    /// Used to distinguish between targets when comparing costs or storing per-target data.
+    fn id(&self) -> TargetId;
+
+    /// Returns a human-readable name for this target.
+    ///
+    /// By default derived from the type name; used for diagnostics and debugging output.
+    fn name(&self) -> &str {
+        const { simplify_type_name(core::any::type_name::<Self>()) }
+    }
+}
+
+/// Execution target that pushes operations into SQL queries.
+///
+/// Supports constants, arithmetic, comparisons, and entity field access for columns stored in
+/// Postgres. Operations are translated to SQL and executed database-side, avoiding data transfer
+/// overhead.
+pub struct Postgres;
+
+impl ExecutionTarget for Postgres {
+    fn id(&self) -> TargetId {
+        TargetId::POSTGRES
+    }
+}
+
+/// Execution target for the HashQL runtime interpreter.
+///
+/// The universal fallback that can execute any MIR statement. Used when specialized targets
+/// cannot handle an operation, or when the operation requires runtime features like closures.
+pub struct Interpreter;
+
+impl ExecutionTarget for Interpreter {
+    fn id(&self) -> TargetId {
+        TargetId::INTERPRETER
+    }
+}
+
+/// Execution target for vector embedding operations.
+///
+/// Routes vector similarity searches and embedding lookups to the dedicated embedding store.
+/// Only supports accessing the `encodings.vectors` path on entities.
+pub struct Embedding; + +impl ExecutionTarget for Embedding { + fn id(&self) -> TargetId { + TargetId::EMBEDDING + } +} + +#[cfg(test)] +mod tests { + use super::{Embedding, ExecutionTarget as _, Interpreter, Postgres, TargetId}; + + /// Target IDs are distinct and don't collide. + #[test] + fn target_ids_are_distinct() { + assert_ne!(TargetId::POSTGRES, TargetId::INTERPRETER); + assert_ne!(TargetId::POSTGRES, TargetId::EMBEDDING); + assert_ne!(TargetId::INTERPRETER, TargetId::EMBEDDING); + } + + /// Target `name()` returns expected strings derived from type names. + #[test] + fn target_names_derived_correctly() { + assert_eq!(Postgres.name(), "Postgres"); + assert_eq!(Interpreter.name(), "Interpreter"); + assert_eq!(Embedding.name(), "Embedding"); + } +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/mod.rs index 0dbc94ecce8..f1c6a27f99c 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/mod.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/mod.rs @@ -1,6 +1,7 @@ mod callgraph; mod data_dependency; pub mod dataflow; +pub mod execution; pub mod size_estimation; pub use self::{ callgraph::{CallGraph, CallGraphAnalysis, CallKind, CallSite}, diff --git a/libs/@local/hashql/mir/src/pass/mod.rs b/libs/@local/hashql/mir/src/pass/mod.rs index 24f3cda10c5..870cd726875 100644 --- a/libs/@local/hashql/mir/src/pass/mod.rs +++ b/libs/@local/hashql/mir/src/pass/mod.rs @@ -422,9 +422,45 @@ pub trait AnalysisPass<'env, 'heap> { } } +/// A global analysis pass over MIR. +/// +/// Unlike [`AnalysisPass`] which operates on a single [`Body`], global analysis passes have +/// access to **all** bodies simultaneously via a [`DefIdSlice`]. This enables inter-procedural +/// analyses that need to: +/// +/// - Build and traverse the call graph +/// - Gather cross-function statistics or diagnostics +/// +/// # When to Use +/// +/// Use `GlobalAnalysisPass` when your analysis requires visibility across multiple functions. 
+/// For single-function analyses, prefer [`AnalysisPass`] which is simpler and allows the pass +/// manager more flexibility in scheduling. +/// +/// # Implementing a Global Analysis Pass +/// +/// ```ignore +/// struct CallGraphAnalysis; +/// +/// impl<'env, 'heap> GlobalAnalysisPass<'env, 'heap> for CallGraphAnalysis { +/// fn run(&mut self, context: &mut MirContext<'env, 'heap>, bodies: &DefIdSlice>) { +/// // Analyze relationships between functions, report diagnostics, etc. +/// } +/// } +/// ``` +/// +/// [`name`]: GlobalAnalysisPass::name pub trait GlobalAnalysisPass<'env, 'heap> { + /// Executes the analysis pass on all bodies. + /// + /// The `context` provides access to the heap allocator, type environment, interner, and + /// diagnostic collection. The `bodies` slice allows reading any function body. fn run(&mut self, context: &mut MirContext<'env, 'heap>, bodies: &DefIdSlice>); + /// Returns a human-readable name for this pass. + /// + /// The default implementation extracts the type name without module path or generic + /// parameters. Override this method to provide a custom name. fn name(&self) -> &'static str { const { simplify_type_name(core::any::type_name::()) } } diff --git a/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/tests.rs b/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/tests.rs index b9c4d372484..41897c0d07a 100644 --- a/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/administrative_reduction/tests.rs @@ -19,7 +19,7 @@ use crate::{ def::{DefId, DefIdSlice}, intern::Interner, pass::{Changed, GlobalTransformPass as _, GlobalTransformState}, - pretty::TextFormat, + pretty::TextFormatOptions, }; /// Tests `TrivialThunk` classification for an identity function (returns parameter). 
@@ -253,12 +253,14 @@ fn assert_admin_reduction_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); text_format .format(DefIdSlice::from_raw(bodies), &[]) diff --git a/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/tests.rs b/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/tests.rs index a77d77fd139..17264a8da00 100644 --- a/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/cfg_simplify/tests.rs @@ -20,7 +20,7 @@ use crate::{ error::MirDiagnosticCategory, intern::Interner, pass::{TransformPass as _, transform::error::TransformationDiagnosticCategory}, - pretty::TextFormat, + pretty::TextFormatOptions, }; #[track_caller] @@ -35,12 +35,14 @@ fn assert_cfg_simplify_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs b/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs index 862349912a8..a836ea7e6a1 100644 --- a/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/copy_propagation/tests.rs @@ -23,7 +23,7 @@ use crate::{ def::DefIdSlice, intern::Interner, pass::TransformPass as _, - pretty::TextFormat, + pretty::TextFormatOptions, }; #[track_caller] @@ -38,12 +38,14 @@ fn assert_cp_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: 
Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pass/transform/dbe/tests.rs b/libs/@local/hashql/mir/src/pass/transform/dbe/tests.rs index c3d0f40d8b8..301917cf175 100644 --- a/libs/@local/hashql/mir/src/pass/transform/dbe/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/dbe/tests.rs @@ -13,7 +13,7 @@ use insta::{Settings, assert_snapshot}; use super::DeadBlockElimination; use crate::{ body::Body, builder::body, context::MirContext, def::DefIdSlice, intern::Interner, - pass::TransformPass as _, pretty::TextFormat, + pass::TransformPass as _, pretty::TextFormatOptions, }; #[track_caller] @@ -28,12 +28,14 @@ fn assert_dbe_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pass/transform/dle/tests.rs b/libs/@local/hashql/mir/src/pass/transform/dle/tests.rs index 055ec3fd03d..80f306c9dc3 100644 --- a/libs/@local/hashql/mir/src/pass/transform/dle/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/dle/tests.rs @@ -19,7 +19,7 @@ use crate::{ def::DefIdSlice, intern::Interner, pass::TransformPass as _, - pretty::TextFormat, + pretty::TextFormatOptions, }; #[track_caller] @@ -34,12 +34,14 @@ fn assert_dle_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pass/transform/dse/tests.rs b/libs/@local/hashql/mir/src/pass/transform/dse/tests.rs index 
0debc6e320e..22ae8d3733e 100644 --- a/libs/@local/hashql/mir/src/pass/transform/dse/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/dse/tests.rs @@ -23,7 +23,7 @@ use crate::{ def::DefIdSlice, intern::Interner, pass::TransformPass as _, - pretty::TextFormat, + pretty::TextFormatOptions, }; #[track_caller] @@ -38,12 +38,14 @@ fn assert_dse_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs b/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs index a0fa3590e85..13dec06b17a 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs @@ -30,7 +30,7 @@ use crate::{ BodyProperties, Candidate, analysis::InlineDirective, heuristics::InlineHeuristics, }, }, - pretty::TextFormat, + pretty::TextFormatOptions, }; /// Creates an identity function: `fn(x: Int) -> Int { return x; }`. 
@@ -84,12 +84,14 @@ fn format_bodies<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); text_format .format(bodies, &[]) diff --git a/libs/@local/hashql/mir/src/pass/transform/inst_simplify/tests.rs b/libs/@local/hashql/mir/src/pass/transform/inst_simplify/tests.rs index 40bd831ed4b..20348fa8e51 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inst_simplify/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inst_simplify/tests.rs @@ -25,7 +25,7 @@ use crate::{ def::DefIdSlice, intern::Interner, pass::{Changed, TransformPass as _}, - pretty::TextFormat, + pretty::TextFormatOptions, }; #[track_caller] @@ -40,12 +40,14 @@ fn assert_inst_simplify_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs index 352ce40a40f..58f938054d6 100644 --- a/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/ssa_repair/tests.rs @@ -17,7 +17,7 @@ use crate::{ def::DefIdSlice, intern::Interner, pass::{TransformPass as _, transform::ssa_repair::SsaRepair}, - pretty::TextFormat, + pretty::TextFormatOptions, }; #[track_caller] @@ -32,12 +32,14 @@ fn assert_ssa_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + 
annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/mod.rs b/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/mod.rs index 0396e1090b4..919256733ee 100644 --- a/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/mod.rs @@ -69,7 +69,11 @@ mod tests; use core::{alloc::Allocator, convert::Infallible}; -use hashql_core::{heap::Heap, id::Id as _, span::SpanId}; +use hashql_core::{ + heap::Heap, + id::{Id as _, bit_vec::DenseBitSet}, + span::SpanId, +}; use crate::{ body::{ @@ -101,7 +105,7 @@ pub struct Traversals<'heap> { } impl<'heap> Traversals<'heap> { - fn with_capacity_in(source: Local, capacity: usize, heap: &'heap Heap) -> Self { + pub(crate) fn with_capacity_in(source: Local, capacity: usize, heap: &'heap Heap) -> Self { Self { source, derivations: LocalVec::with_capacity_in(capacity, heap), @@ -120,6 +124,19 @@ impl<'heap> Traversals<'heap> { pub fn lookup(&self, local: Local) -> Option<&Place<'heap>> { self.derivations.lookup(local) } + + #[must_use] + pub fn enabled(&self, body: &Body<'heap>) -> DenseBitSet { + let mut set = DenseBitSet::new_empty(body.local_decls.len()); + + for (local, place) in self.derivations.iter_enumerated() { + if place.is_some() { + set.insert(local); + } + } + + set + } } /// Visitor that extracts projections from a target local into separate bindings. 
diff --git a/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/tests.rs b/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/tests.rs index f7791916252..0b227dbb1ed 100644 --- a/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/traversal_extraction/tests.rs @@ -19,7 +19,7 @@ use crate::{ def::DefIdSlice, intern::Interner, pass::{TransformPass as _, transform::traversal_extraction::TraversalExtraction}, - pretty::TextFormat, + pretty::TextFormatOptions, }; #[track_caller] @@ -34,12 +34,14 @@ fn assert_traversal_pass<'heap>( context.env, TypeFormatterOptions::terse().with_qualified_opaque_names(true), ); - let mut text_format = TextFormat { + let mut text_format = TextFormatOptions { writer: Vec::new(), indent: 4, sources: (), types: &mut formatter, - }; + annotations: (), + } + .build(); let mut bodies = [body]; diff --git a/libs/@local/hashql/mir/src/pretty/d2.rs b/libs/@local/hashql/mir/src/pretty/d2.rs index 92aa990a0d9..b3bec2898d5 100644 --- a/libs/@local/hashql/mir/src/pretty/d2.rs +++ b/libs/@local/hashql/mir/src/pretty/d2.rs @@ -7,17 +7,26 @@ use std::io; use bstr::ByteSlice as _; use hashql_core::{pretty::RenderFormat, r#type::TypeFormatter}; -use super::{DataFlowLookup, FormatPart, SourceLookup, TextFormat, text::HighlightBody}; +use super::{ + DataFlowLookup, FormatPart, SourceLookup, TextFormat, + text::{HighlightBody, TextFormatOptions}, +}; use crate::{ body::{ Body, basic_block::{BasicBlock, BasicBlockId}, + location::Location, terminator::{Goto, GraphRead, SwitchInt, TerminatorKind}, }, def::{DefId, DefIdSlice}, pretty::text::{Signature, SignatureOptions, TargetParams, TerminatorHead}, }; +/// A double buffer used for HTML escaping during D2 output generation. +/// +/// This buffer uses a front/back swap pattern to efficiently perform multiple +/// string replacements (escaping `&`, `<`, `>`, and newlines) without repeated +/// allocations. 
#[derive(Debug, Default)] pub struct D2Buffer { front: Vec, @@ -93,21 +102,25 @@ where fn format_text_unescaped(&mut self, value: V) -> io::Result<()> where - for<'a> TextFormat<&'a mut W, &'a S, &'a mut TypeFormatter<'fmt, 'fmt, 'heap>>: + for<'a> TextFormat<&'a mut W, &'a S, &'a mut TypeFormatter<'fmt, 'fmt, 'heap>, ()>: FormatPart, { - TextFormat { + let mut text = TextFormatOptions { writer: &mut self.writer, indent: 0, sources: &self.sources, types: &mut self.types, + annotations: (), } - .format_part(value) + .build(); + + text.format_part(value)?; + text.flush() } fn format_text(&mut self, value: V) -> io::Result<()> where - for<'a> TextFormat<&'a mut Vec, &'a S, &'a mut TypeFormatter<'fmt, 'fmt, 'heap>>: + for<'a> TextFormat<&'a mut Vec, &'a S, &'a mut TypeFormatter<'fmt, 'fmt, 'heap>, ()>: FormatPart, { const REPLACEMENTS: [(u8, &[u8]); 4] = [ @@ -119,13 +132,17 @@ where ]; self.buffer.clear(); - TextFormat { + let mut text = TextFormatOptions { writer: &mut self.buffer.front, indent: 0, sources: &self.sources, types: &mut self.types, + annotations: (), } - .format_part(value)?; + .build(); + + text.format_part(value)?; + text.flush()?; self.buffer.back.reserve(self.buffer.front.len()); @@ -151,7 +168,7 @@ where aux: impl IntoIterator, ) -> io::Result<()> where - for<'a> TextFormat<&'a mut Vec, &'a S, &'a mut TypeFormatter<'fmt, 'fmt, 'heap>>: + for<'a> TextFormat<&'a mut Vec, &'a S, &'a mut TypeFormatter<'fmt, 'fmt, 'heap>, ()>: FormatPart, { let valign = if valign_bottom { "bottom" } else { "top" }; @@ -365,7 +382,18 @@ where for (index, statement) in block.statements.iter().enumerate() { let aux = self.dataflow.on_statement(def_id, block_id, index); - self.write_row(false, index, statement, aux)?; + self.write_row( + false, + index, + ( + Location { + block: block_id, + statement_index: index + 1, + }, + statement, + ), + aux, + )?; } self.write_row( diff --git a/libs/@local/hashql/mir/src/pretty/mod.rs b/libs/@local/hashql/mir/src/pretty/mod.rs index 
e217c091d5b..6a591b86f47 100644 --- a/libs/@local/hashql/mir/src/pretty/mod.rs +++ b/libs/@local/hashql/mir/src/pretty/mod.rs @@ -22,7 +22,7 @@ mod d2; mod text; pub use d2::{D2Buffer, D2Format}; -pub use text::TextFormat; +pub use text::{TextFormat, TextFormatAnnotations, TextFormatOptions}; /// A trait for looking up source information associated with function definitions. /// diff --git a/libs/@local/hashql/mir/src/pretty/text.rs b/libs/@local/hashql/mir/src/pretty/text.rs index d1732724995..d588455784c 100644 --- a/libs/@local/hashql/mir/src/pretty/text.rs +++ b/libs/@local/hashql/mir/src/pretty/text.rs @@ -1,6 +1,7 @@ // Textual representation of bodies, based on a similar syntax used by rustc -use std::io; +use core::fmt::Display; +use std::io::{self, Write as _}; use hashql_core::{ id::Id as _, @@ -17,7 +18,8 @@ use crate::{ Body, Source, basic_block::{BasicBlock, BasicBlockId}, constant::Constant, - local::Local, + local::{Local, LocalDecl}, + location::Location, operand::Operand, place::{Place, ProjectionKind}, rvalue::{Aggregate, AggregateKind, Apply, Binary, Input, RValue, Unary}, @@ -40,18 +42,87 @@ const fn source_keyword(source: Source<'_>) -> &'static str { } } +/// Configuration options for formatting function signatures. pub(crate) struct SignatureOptions { + /// The output format (plain text or HTML fragment). pub format: RenderFormat, } /// A wrapper for formatting function signatures from MIR bodies. +/// +/// Renders the function keyword, name, parameter list with types, and return type. pub(crate) struct Signature<'body, 'heap>(pub &'body Body<'heap>, pub SignatureOptions); /// A helper struct for formatting key-value pairs with consistent syntax. struct KeyValuePair(K, V); +/// A set of definition IDs to visually highlight in formatted output. +/// +/// When formatting multiple MIR bodies, those with IDs in this set will be +/// marked distinctly (e.g., prefixed with `*` in text output or colored in D2). 
pub(crate) struct HighlightBody<'def>(pub &'def [DefId]); +/// A trait for providing inline annotations during text formatting. +/// +/// Implementations can attach comments or annotations to statements and local +/// declarations in the formatted output. Annotations appear as trailing comments +/// (e.g., `// annotation`) after the formatted line. +pub trait TextFormatAnnotations { + /// The type of annotation displayed after statements. + type StatementAnnotation<'this, 'heap>: Display + = ! + where + Self: 'this; + + /// The type of annotation displayed after local declarations. + type DeclarationAnnotation<'this, 'heap>: Display + = ! + where + Self: 'this; + + /// Returns an optional annotation for the given statement at `location`. + #[expect(unused_variables, reason = "trait definition")] + fn annotate_statement<'heap>( + &self, + location: Location, + statement: &Statement<'heap>, + ) -> Option> { + None + } + + /// Returns an optional annotation for the given local declaration. + #[expect(unused_variables, reason = "trait definition")] + fn annotate_local_decl<'heap>( + &self, + local: Local, + declaration: &LocalDecl<'heap>, + ) -> Option> { + None + } +} + +impl TextFormatAnnotations for () {} + +/// Configuration for constructing a [`TextFormat`] formatter. +pub struct TextFormatOptions { + /// The writer where formatted text will be written. + pub writer: W, + /// Number of spaces per indentation level. + pub indent: usize, + /// Source lookup for resolving symbols and identifiers. + pub sources: S, + /// Type formatter for rendering type information. + pub types: T, + /// Annotation provider for adding inline comments. + pub annotations: A, +} + +impl TextFormatOptions { + pub fn build(self) -> TextFormat { + TextFormat::new(self) + } +} + /// A text-based formatter for MIR (Middle Intermediate Representation) structures. 
/// /// This formatter converts MIR components into human-readable text representation, @@ -64,18 +135,42 @@ pub(crate) struct HighlightBody<'def>(pub &'def [DefId]); /// - `W`: A writer implementing [`io::Write`] for text output /// - `S`: A source lookup implementing [`SourceLookup`] for symbol resolution /// - `T`: A type which implements [`AsMut`] for type information -pub struct TextFormat { +pub struct TextFormat { /// The writer where formatted text will be written. pub writer: W, /// Amount of indention per level. - pub indent: usize, + indent: usize, /// Source lookup for resolving symbols and identifiers. - pub sources: S, + sources: S, /// Type formatter for formatting type information. - pub types: T, + types: T, + annotations: A, + + line_buffer: Vec, +} + +impl TextFormat { + pub fn new( + TextFormatOptions { + writer, + indent, + sources, + types, + annotations, + }: TextFormatOptions, + ) -> Self { + Self { + writer, + indent, + sources, + types, + annotations, + line_buffer: Vec::new(), + } + } } -impl TextFormat +impl TextFormat where W: io::Write, { @@ -96,8 +191,10 @@ where where S: SourceLookup<'heap>, T: AsMut>, + A: TextFormatAnnotations, { - self.format_part((bodies, HighlightBody(highlight))) + self.format_part((bodies, HighlightBody(highlight)))?; + self.flush() } /// Formats a single MIR body as human-readable text. 
@@ -117,8 +214,10 @@ where where S: SourceLookup<'heap>, T: AsMut>, + A: TextFormatAnnotations, { - self.format_part((body, BodyRenderOptions { highlight: false })) + self.format_part((body, BodyRenderOptions { highlight: false }))?; + self.flush() } fn separated_list( @@ -137,7 +236,7 @@ where self.format_part(first)?; for value in values { - self.writer.write_all(sep)?; + self.line_buffer.write_all(sep)?; self.format_part(value)?; } @@ -152,11 +251,30 @@ where } fn indent(&mut self, level: usize) -> io::Result<()> { - write!(self.writer, "{:width$}", "", width = level * self.indent) + write!( + self.line_buffer, + "{:width$}", + "", + width = level * self.indent + ) + } + + fn newline(&mut self) -> io::Result<()> { + self.line_buffer.push(b'\n'); + self.writer.write_all(&self.line_buffer)?; + + self.line_buffer.clear(); + Ok(()) + } + + pub(crate) fn flush(&mut self) -> io::Result<()> { + self.writer.write_all(&self.line_buffer)?; + self.line_buffer.clear(); + Ok(()) } } -impl<'heap, W, S, T> FormatPart for TextFormat +impl<'heap, W, S, T, A> FormatPart for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -166,39 +284,39 @@ where if let Some(source) = source { self.format_part(source) } else { - write!(self.writer, "{{def@{value}}}") + write!(self.line_buffer, "{{def@{value}}}") } } } -impl FormatPart<&str> for TextFormat +impl FormatPart<&str> for TextFormat where W: io::Write, { fn format_part(&mut self, value: &str) -> io::Result<()> { - write!(self.writer, "{value}") + write!(self.line_buffer, "{value}") } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, { fn format_part(&mut self, value: Symbol<'heap>) -> io::Result<()> { - write!(self.writer, "{value}") + write!(self.line_buffer, "{value}") } } -impl FormatPart for TextFormat +impl FormatPart for TextFormat where W: io::Write, { fn format_part(&mut self, value: Local) -> io::Result<()> { - write!(self.writer, "{value}") + 
write!(self.line_buffer, "{value}") } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, { @@ -207,12 +325,12 @@ where for projection in projections { match projection.kind { - ProjectionKind::Field(index) => write!(self.writer, ".{index}")?, - ProjectionKind::FieldByName(symbol) => write!(self.writer, ".{symbol}")?, + ProjectionKind::Field(index) => write!(self.line_buffer, ".{index}")?, + ProjectionKind::FieldByName(symbol) => write!(self.line_buffer, ".{symbol}")?, ProjectionKind::Index(local) => { - write!(self.writer, "[")?; + write!(self.line_buffer, "[")?; self.format_part(local)?; - write!(self.writer, "]")?; + write!(self.line_buffer, "]")?; } } } @@ -221,26 +339,26 @@ where } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, { fn format_part(&mut self, value: Constant<'heap>) -> io::Result<()> { match value { - Constant::Int(int) => write!(self.writer, "{int}"), - Constant::Primitive(primitive) => write!(self.writer, "{primitive}"), - Constant::Unit => self.writer.write_all(b"()"), + Constant::Int(int) => write!(self.line_buffer, "{int}"), + Constant::Primitive(primitive) => write!(self.line_buffer, "{primitive}"), + Constant::Unit => self.line_buffer.write_all(b"()"), Constant::FnPtr(def) => { - self.writer.write_all(b"(")?; + self.line_buffer.write_all(b"(")?; self.format_part(def)?; - self.writer.write_all(b" as FnPtr)") + self.line_buffer.write_all(b" as FnPtr)") } } } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -253,7 +371,7 @@ where } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, { @@ -261,100 +379,114 @@ where let mut named_symbol = |name, id, binder: Option>| { if let Some(binder) = binder { if let 
Some(name) = binder.name { - return self.writer.write_all(name.as_bytes()); + return self.line_buffer.write_all(name.as_bytes()); } - return write!(self.writer, "{{{name}#{}}}", binder.id); + return write!(self.line_buffer, "{{{name}#{}}}", binder.id); } - write!(self.writer, "{{{name}@{id}}}") + write!(self.line_buffer, "{{{name}@{id}}}") }; match value { Source::Ctor(symbol) => { - write!(self.writer, "{{ctor#{symbol}}}") + write!(self.line_buffer, "{{ctor#{symbol}}}") } Source::Closure(id, binder) => named_symbol("closure", id, binder), Source::GraphReadFilter(id) => named_symbol("graph::read::filter", id, None), Source::Thunk(id, binder) => named_symbol("thunk", id, binder), Source::Intrinsic(def_id) => { - write!(self.writer, "{{intrinsic#{def_id}}}") + write!(self.line_buffer, "{{intrinsic#{def_id}}}") } } } } -impl<'heap, W, S, T> FormatPart<(BasicBlockId, &BasicBlock<'heap>)> for TextFormat +impl<'heap, W, S, T, A> FormatPart<(BasicBlockId, &BasicBlock<'heap>)> for TextFormat where W: io::Write, S: SourceLookup<'heap>, + A: TextFormatAnnotations, { fn format_part(&mut self, (id, block): (BasicBlockId, &BasicBlock<'heap>)) -> io::Result<()> { self.indent(1)?; - write!(self.writer, "{id}(")?; + write!(self.line_buffer, "{id}(")?; self.csv(block.params.iter().copied())?; - writeln!(self.writer, "): {{")?; + write!(self.line_buffer, "): {{")?; + self.newline()?; + + let mut location = Location { + block: id, + statement_index: 0, + }; for statement in &block.statements { - self.format_part(statement)?; - self.writer.write_all(b"\n")?; + location.statement_index += 1; + + self.format_part((location, statement))?; + self.newline()?; } if !block.statements.is_empty() { - self.writer.write_all(b"\n")?; + self.newline()?; } self.indent(2)?; self.format_part(&block.terminator)?; - self.writer.write_all(b"\n")?; + self.newline()?; self.indent(1)?; - writeln!(self.writer, "}}")?; + write!(self.line_buffer, "}}")?; + self.newline()?; Ok(()) } } /// A wrapper for 
formatting target parameters in MIR terminators. +/// +/// Renders the argument list as `(arg1, arg2, ...)` for goto and switch targets. pub(crate) struct TargetParams<'heap>(pub Interned<'heap, [Operand<'heap>]>); -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, { fn format_part(&mut self, TargetParams(args): TargetParams<'heap>) -> io::Result<()> { - write!(self.writer, "(")?; + write!(self.line_buffer, "(")?; self.csv(args.iter().copied())?; - write!(self.writer, ")") + write!(self.line_buffer, ")") } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, { fn format_part(&mut self, Target { block, args }: Target<'heap>) -> io::Result<()> { - write!(self.writer, "{block}")?; + write!(self.line_buffer, "{block}")?; self.format_part(TargetParams(args)) } } struct AnonymousTarget(BasicBlockId); -impl<'heap, W, S, T> FormatPart for TextFormat +impl<'heap, W, S, T, A> FormatPart for TextFormat where W: io::Write, S: SourceLookup<'heap>, { fn format_part(&mut self, AnonymousTarget(id): AnonymousTarget) -> io::Result<()> { - write!(self.writer, "{id}(_)") + write!(self.line_buffer, "{id}(_)") } } +/// Configuration options for formatting types in MIR output. pub(crate) struct TypeOptions { + /// The output format (plain text or HTML fragment). format: RenderFormat, } @@ -368,9 +500,10 @@ impl TypeOptions { } } +/// A wrapper for formatting a type with specific rendering options. 
pub(crate) struct Type(TypeId, TypeOptions); -impl<'fmt, 'env, 'heap: 'fmt + 'env, W, S, T> FormatPart for TextFormat +impl<'fmt, 'env, 'heap: 'fmt + 'env, W, S, T, A> FormatPart for TextFormat where W: io::Write, T: AsMut>, @@ -378,21 +511,21 @@ where fn format_part(&mut self, Type(r#type, options): Type) -> io::Result<()> { self.types .as_mut() - .render_into(r#type, options.render(), &mut self.writer) + .render_into(r#type, options.render(), &mut self.line_buffer) } } -impl<'body, 'fmt, 'env, 'heap: 'fmt + 'env, W, S, T> FormatPart> - for TextFormat +impl<'body, 'fmt, 'env, 'heap: 'fmt + 'env, W, S, T, A> FormatPart> + for TextFormat where W: io::Write, T: AsMut>, { fn format_part(&mut self, Signature(body, options): Signature<'body, 'heap>) -> io::Result<()> { - write!(self.writer, "{} ", source_keyword(body.source))?; + write!(self.line_buffer, "{} ", source_keyword(body.source))?; self.format_part(body.source)?; - self.writer.write_all(b"(")?; + self.line_buffer.write_all(b"(")?; self.csv((0..body.args).map(Local::new).map(|local| { let decl = body.local_decls[local]; @@ -406,7 +539,7 @@ where ), ) }))?; - self.writer.write_all(b") -> ")?; + self.line_buffer.write_all(b") -> ")?; self.format_part(Type( body.return_type, TypeOptions { @@ -418,7 +551,7 @@ where } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -426,15 +559,15 @@ where fn format_part(&mut self, value: GraphReadHead<'heap>) -> io::Result<()> { match value { GraphReadHead::Entity { axis } => { - self.writer.write_all(b"entities(")?; + self.line_buffer.write_all(b"entities(")?; self.format_part(axis)?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } } } } -impl<'heap, W, S, T> FormatPart for TextFormat +impl<'heap, W, S, T, A> FormatPart for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -442,28 +575,28 @@ where fn format_part(&mut self, value: GraphReadBody) -> 
io::Result<()> { match value { GraphReadBody::Filter(def_id, local) => { - self.writer.write_all(b"filter(")?; + self.line_buffer.write_all(b"filter(")?; self.format_part(def_id)?; - self.writer.write_all(b", ")?; + self.line_buffer.write_all(b", ")?; self.format_part(local)?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } } } } -impl FormatPart for TextFormat +impl FormatPart for TextFormat where W: io::Write, { fn format_part(&mut self, value: GraphReadTail) -> io::Result<()> { match value { - GraphReadTail::Collect => self.writer.write_all(b"collect"), + GraphReadTail::Collect => self.line_buffer.write_all(b"collect"), } } } -impl<'heap, W, S, T> FormatPart<&GraphRead<'heap>> for TextFormat +impl<'heap, W, S, T, A> FormatPart<&GraphRead<'heap>> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -477,19 +610,19 @@ where target: _, }: &GraphRead<'heap>, ) -> io::Result<()> { - self.writer.write_all(b"graph read ")?; + self.line_buffer.write_all(b"graph read ")?; self.format_part(*head)?; for &body in body { - self.writer.write_all(b"\n")?; + self.newline()?; self.indent(2)?; - self.writer.write_all(b"|> ")?; + self.line_buffer.write_all(b"|> ")?; self.format_part(body)?; } - self.writer.write_all(b"\n")?; + self.newline()?; self.indent(2)?; - self.writer.write_all(b"|> ")?; + self.line_buffer.write_all(b"|> ")?; self.format_part(*tail) } } @@ -500,42 +633,42 @@ pub(crate) struct TerminatorHead<'terminator, 'heap>(pub &'terminator Terminator /// A wrapper for formatting the tail (target and arguments) part of MIR terminators. 
pub(crate) struct TerminatorTail<'terminator, 'heap>(pub &'terminator TerminatorKind<'heap>); -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, { fn format_part(&mut self, TerminatorHead(kind): TerminatorHead<'_, 'heap>) -> io::Result<()> { match kind { - &TerminatorKind::Goto(Goto { target: _ }) => write!(self.writer, "goto"), + &TerminatorKind::Goto(Goto { target: _ }) => write!(self.line_buffer, "goto"), &TerminatorKind::SwitchInt(SwitchInt { discriminant, targets: _, }) => { - write!(self.writer, "switchInt(")?; + write!(self.line_buffer, "switchInt(")?; self.format_part(discriminant)?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } &TerminatorKind::Return(Return { value }) => { - write!(self.writer, "return ")?; + write!(self.line_buffer, "return ")?; self.format_part(value) } TerminatorKind::GraphRead(graph_read) => self.format_part(graph_read), - TerminatorKind::Unreachable => write!(self.writer, "unreachable"), + TerminatorKind::Unreachable => write!(self.line_buffer, "unreachable"), } } } -impl FormatPart for TextFormat +impl FormatPart for TextFormat where W: io::Write, { fn format_part(&mut self, value: u128) -> io::Result<()> { - write!(self.writer, "{value}") + write!(self.line_buffer, "{value}") } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -543,24 +676,24 @@ where fn format_part(&mut self, TerminatorTail(kind): TerminatorTail<'_, 'heap>) -> io::Result<()> { match kind { &TerminatorKind::Goto(Goto { target }) => { - write!(self.writer, " -> ")?; + write!(self.line_buffer, " -> ")?; self.format_part(target) } TerminatorKind::SwitchInt(SwitchInt { discriminant: _, targets, }) => { - write!(self.writer, " -> [")?; + write!(self.line_buffer, " -> [")?; self.csv( targets .iter() .map(|(value, target)| KeyValuePair(value, target)), )?; if 
let Some(otherwise) = targets.otherwise() { - write!(self.writer, ", otherwise: ")?; + write!(self.line_buffer, ", otherwise: ")?; self.format_part(otherwise)?; } - write!(self.writer, "]") + write!(self.line_buffer, "]") } &TerminatorKind::Return(_) | TerminatorKind::Unreachable => Ok(()), TerminatorKind::GraphRead(GraphRead { @@ -569,14 +702,14 @@ where tail: _, target, }) => { - write!(self.writer, " -> ")?; + write!(self.line_buffer, " -> ")?; self.format_part(AnonymousTarget(*target)) } } } } -impl<'heap, W, S, T> FormatPart<&Terminator<'heap>> for TextFormat +impl<'heap, W, S, T, A> FormatPart<&Terminator<'heap>> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -587,42 +720,42 @@ where } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, { fn format_part(&mut self, Binary { op, left, right }: Binary<'heap>) -> io::Result<()> { self.format_part(left)?; - write!(self.writer, " {} ", op.as_str())?; + write!(self.line_buffer, " {} ", op.as_str())?; self.format_part(right) } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, S: SourceLookup<'heap>, { fn format_part(&mut self, Unary { op, operand }: Unary<'heap>) -> io::Result<()> { - write!(self.writer, "{}", op.as_str())?; + write!(self.line_buffer, "{}", op.as_str())?; self.format_part(operand) } } -impl FormatPart> for TextFormat +impl FormatPart> for TextFormat where W: io::Write, Self: FormatPart + FormatPart, { fn format_part(&mut self, KeyValuePair(key, value): KeyValuePair) -> io::Result<()> { self.format_part(key)?; - self.writer.write_all(b": ")?; + self.line_buffer.write_all(b": ")?; self.format_part(value) } } -impl<'heap, W, S, T> FormatPart<&Aggregate<'heap>> for TextFormat +impl<'heap, W, S, T, A> FormatPart<&Aggregate<'heap>> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -630,12 +763,12 @@ where fn 
format_part(&mut self, Aggregate { kind, operands }: &Aggregate<'heap>) -> io::Result<()> { match kind { AggregateKind::Tuple => { - self.writer.write_all(b"(")?; + self.line_buffer.write_all(b"(")?; self.csv(operands.iter().copied())?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } AggregateKind::Struct { fields } => { - self.writer.write_all(b"(")?; + self.line_buffer.write_all(b"(")?; self.csv( fields @@ -644,15 +777,15 @@ where .map(|(&key, &value)| KeyValuePair(key, value)), )?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } AggregateKind::List => { - self.writer.write_all(b"list(")?; + self.line_buffer.write_all(b"list(")?; self.csv(operands.iter().copied())?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } AggregateKind::Dict => { - self.writer.write_all(b"dict(")?; + self.line_buffer.write_all(b"dict(")?; self.csv( operands @@ -662,45 +795,45 @@ where .map(|[key, value]| KeyValuePair(key, value)), )?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } AggregateKind::Opaque(symbol) => { - self.writer.write_all(b"opaque(")?; - self.writer.write_all(symbol.as_bytes())?; - self.writer.write_all(b", ")?; + self.line_buffer.write_all(b"opaque(")?; + self.line_buffer.write_all(symbol.as_bytes())?; + self.line_buffer.write_all(b", ")?; self.csv(operands.iter().copied())?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } AggregateKind::Closure => { - self.writer.write_all(b"closure(")?; + self.line_buffer.write_all(b"closure(")?; self.csv(operands.iter().copied())?; - self.writer.write_all(b")") + self.line_buffer.write_all(b")") } } } } -impl<'heap, W, S, T> FormatPart> for TextFormat +impl<'heap, W, S, T, A> FormatPart> for TextFormat where W: io::Write, { fn format_part(&mut self, Input { op, name }: Input<'heap>) -> io::Result<()> { - self.writer.write_all(b"input ")?; + self.line_buffer.write_all(b"input ")?; match op { InputOp::Load { required: _ } => { - 
self.writer.write_all(b"LOAD ")?; + self.line_buffer.write_all(b"LOAD ")?; } InputOp::Exists => { - self.writer.write_all(b"EXISTS ")?; + self.line_buffer.write_all(b"EXISTS ")?; } } - self.writer.write_all(name.as_bytes()) + self.line_buffer.write_all(name.as_bytes()) } } -impl<'heap, W, S, T> FormatPart<&Apply<'heap>> for TextFormat +impl<'heap, W, S, T, A> FormatPart<&Apply<'heap>> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -712,11 +845,11 @@ where arguments, }: &Apply<'heap>, ) -> io::Result<()> { - self.writer.write_all(b"apply ")?; + self.line_buffer.write_all(b"apply ")?; self.format_part(*function)?; for argument in arguments { - self.writer.write_all(b" ")?; + self.line_buffer.write_all(b" ")?; self.format_part(*argument)?; } @@ -724,7 +857,7 @@ where } } -impl<'heap, W, S, T> FormatPart<&RValue<'heap>> for TextFormat +impl<'heap, W, S, T, A> FormatPart<&RValue<'heap>> for TextFormat where W: io::Write, S: SourceLookup<'heap>, @@ -741,30 +874,47 @@ where } } -impl<'heap, W, S, T> FormatPart<&Statement<'heap>> for TextFormat +impl<'heap, W, S, T, A> FormatPart<(Location, &Statement<'heap>)> for TextFormat where W: io::Write, S: SourceLookup<'heap>, + A: TextFormatAnnotations, { - fn format_part(&mut self, Statement { span: _, kind }: &Statement<'heap>) -> io::Result<()> { + fn format_part( + &mut self, + (location, statement @ Statement { span: _, kind }): (Location, &Statement<'heap>), + ) -> io::Result<()> { self.indent(2)?; match kind { StatementKind::Assign(Assign { lhs, rhs }) => { self.format_part(*lhs)?; - self.writer.write_all(b" = ")?; - self.format_part(rhs) + self.line_buffer.write_all(b" = ")?; + self.format_part(rhs)?; } - StatementKind::Nop => self.writer.write_all(b"nop"), + StatementKind::Nop => self.line_buffer.write_all(b"nop")?, &StatementKind::StorageLive(local) => { - self.writer.write_all(b"let ")?; - self.format_part(local) + self.line_buffer.write_all(b"let ")?; + self.format_part(local)?; } 
&StatementKind::StorageDead(local) => { - self.writer.write_all(b"drop ")?; - self.format_part(local) + self.line_buffer.write_all(b"drop ")?; + self.format_part(local)?; } } + + let Some(annotation) = self.annotations.annotate_statement(location, statement) else { + return Ok(()); + }; + + // We estimate that we never exceed 80 columns, calculate the remaining width, if we don't + // have enough space, we add 4 spaces breathing room. + let remaining_width = 80_usize.checked_sub(self.line_buffer.len()).unwrap_or(4); + self.line_buffer + .resize(self.line_buffer.len() + remaining_width, b' '); + write!(self.line_buffer, "// {annotation}")?; + + Ok(()) } } @@ -772,19 +922,20 @@ struct BodyRenderOptions { highlight: bool, } -impl<'fmt, 'env, 'heap: 'fmt + 'env, W, S, T> FormatPart<(&Body<'heap>, BodyRenderOptions)> - for TextFormat +impl<'fmt, 'env, 'heap: 'fmt + 'env, W, S, T, A> FormatPart<(&Body<'heap>, BodyRenderOptions)> + for TextFormat where W: io::Write, S: SourceLookup<'heap>, T: AsMut>, + A: TextFormatAnnotations, { fn format_part( &mut self, (body, options): (&Body<'heap>, BodyRenderOptions), ) -> io::Result<()> { if options.highlight { - self.writer.write_all(b"*")?; + self.line_buffer.write_all(b"*")?; } self.format_part(Signature( @@ -793,44 +944,56 @@ where format: RenderFormat::Plain, }, ))?; - self.writer.write_all(b" {\n")?; + self.line_buffer.write_all(b" {")?; + self.newline()?; // Do not render locals that are arguments, as they are already rendered in the signature for (local, decl) in body.local_decls.iter_enumerated().skip(body.args) { self.indent(1)?; - write!(self.writer, "let {local}: ")?; + write!(self.line_buffer, "let {local}: ")?; self.format_part(Type( decl.r#type, TypeOptions { format: RenderFormat::Plain, }, ))?; - self.writer.write_all(b"\n")?; + + if let Some(annotation) = self.annotations.annotate_local_decl(local, decl) { + // We estimate that we never exceed 80 columns, calculate the remaining width, if we + // don't have enough 
space, we add 4 spaces breathing room. + let remaining_width = 80_usize.checked_sub(self.line_buffer.len()).unwrap_or(4); + self.line_buffer + .resize(self.line_buffer.len() + remaining_width, b' '); + write!(self.line_buffer, "// {annotation}")?; + } + + self.newline()?; } if body.local_decls.len() > body.args { - self.writer.write_all(b"\n")?; + self.newline()?; } for (index, block) in body.basic_blocks.iter_enumerated() { if index.as_usize() > 0 { - self.writer.write_all(b"\n")?; + self.newline()?; } self.format_part((index, block))?; } - self.writer.write_all(b"}")?; + self.line_buffer.write_all(b"}")?; Ok(()) } } -impl<'fmt, 'env, 'heap: 'fmt + 'env, W, S, T> - FormatPart<(&DefIdSlice>, HighlightBody<'_>)> for TextFormat +impl<'fmt, 'env, 'heap: 'fmt + 'env, W, S, T, A> + FormatPart<(&DefIdSlice>, HighlightBody<'_>)> for TextFormat where W: io::Write, S: SourceLookup<'heap>, T: AsMut>, + A: TextFormatAnnotations, { fn format_part( &mut self, diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/all_args_excluded.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/all_args_excluded.snap new file mode 100644 index 00000000000..283ab6f3e9f --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/all_args_excluded.snap @@ -0,0 +1,19 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (Integer,), %1: Entity) -> Boolean { + let %2: Integer + let %3: Boolean + + bb0(): { + %2 = %0.0 + %3 = %2 == 42 + + return %3 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/non_vectors_entity_projection_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/non_vectors_entity_projection_rejected.snap new file mode 
100644 index 00000000000..981a34bc5e7 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/non_vectors_entity_projection_rejected.snap @@ -0,0 +1,17 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { + let %2: Boolean + + bb0(): { + %2 = %1.metadata.archived + + return %2 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/only_vectors_projection_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/only_vectors_projection_supported.snap new file mode 100644 index 00000000000..b9e86c3d371 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/only_vectors_projection_supported.snap @@ -0,0 +1,18 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { + let %2: ? 
+ + bb0(): { + %2 = %1.encodings.vectors // cost: 4 + + return %2 + } +} + +=================== Traversals =================== + +Traversals: + %2: 4 diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/other_operations_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/other_operations_rejected.snap new file mode 100644 index 00000000000..14d5b00838e --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/other_operations_rejected.snap @@ -0,0 +1,35 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (Integer,), %1: Entity) -> Boolean { + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + let %6: (Integer, Integer) + let %7: Integer + let %8: (Integer,) + let %9: (Integer) -> Integer + let %10: Integer + let %11: Boolean + + bb0(): { + %2 = 10 + %3 = 20 + %4 = %2 + %3 + %5 = -%4 + %6 = (1, 2) + %7 = input LOAD param + %8 = %0 + %9 = closure(({def@123} as FnPtr), %8) + %10 = apply %9 1 + %11 = 1 + + return %11 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/storage_statements_zero_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/storage_statements_zero_cost.snap new file mode 100644 index 00000000000..eef17c67fe6 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/storage_statements_zero_cost.snap @@ -0,0 +1,20 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { + let %2: ? 
+ + bb0(): { + let %2 // cost: 0 + %2 = %1.encodings.vectors // cost: 4 + drop %2 // cost: 0 + + return %2 + } +} + +=================== Traversals =================== + +Traversals: + %2: 4 diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/all_statements_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/all_statements_supported.snap new file mode 100644 index 00000000000..07b1dd10829 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/all_statements_supported.snap @@ -0,0 +1,37 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (Integer,), %1: Entity) -> Boolean { + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + let %6: (Integer, Integer) + let %7: (a: Integer, b: Integer) + let %8: (Integer,) + let %9: (Integer) -> Integer + let %10: Integer + let %11: Integer + let %12: Boolean + + bb0(): { + %2 = 10 // cost: 8 + %3 = 20 // cost: 8 + %4 = %2 + %3 // cost: 8 + %5 = -%4 // cost: 8 + %6 = (1, 2) // cost: 8 + %7 = (a: 3, b: 4) // cost: 8 + %8 = %0 // cost: 8 + %9 = closure(({def@42} as FnPtr), %8) // cost: 8 + %10 = apply %9 5 // cost: 8 + %11 = input LOAD param // cost: 8 + %12 = 1 // cost: 8 + + return %12 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/storage_statements_zero_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/storage_statements_zero_cost.snap new file mode 100644 index 00000000000..826771d4f33 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/storage_statements_zero_cost.snap @@ -0,0 +1,25 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs 
+expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Integer { + let %2: Integer + let %3: Integer + let %4: Integer + + bb0(): { + let %2 // cost: 0 + %2 = 10 // cost: 8 + let %3 // cost: 0 + %3 = 20 // cost: 8 + %4 = %2 + %3 // cost: 8 + drop %2 // cost: 0 + drop %3 // cost: 0 + + return %4 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_closure_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_closure_rejected.snap new file mode 100644 index 00000000000..d32fd93f7ff --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_closure_rejected.snap @@ -0,0 +1,21 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: Integer, %1: Entity) -> Boolean { + let %2: Integer + let %3: (Integer) -> Integer + let %4: Boolean + + bb0(): { + %2 = %0 // cost: 4 + %3 = closure(({def@42} as FnPtr), %2) + %4 = 1 // cost: 4 + + return %4 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_tuple_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_tuple_supported.snap new file mode 100644 index 00000000000..16347008ee5 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_tuple_supported.snap @@ -0,0 +1,21 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { + let %2: (Integer, Integer) + let %3: (a: Integer, b: Integer) + let %4: Boolean + + bb0(): { + %2 = (1, 2) 
// cost: 4 + %3 = (a: 10, b: 20) // cost: 4 + %4 = 1 // cost: 4 + + return %4 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/apply_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/apply_rejected.snap new file mode 100644 index 00000000000..4d9113d12f9 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/apply_rejected.snap @@ -0,0 +1,23 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: Integer, %1: Entity) -> Boolean { + let %2: Integer + let %3: (Integer) -> Integer + let %4: Integer + let %5: Boolean + + bb0(): { + %2 = %0 // cost: 4 + %3 = closure(({def@99} as FnPtr), %2) + %4 = apply %3 42 + %5 = %4 == 0 + + return %5 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/binary_unary_ops_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/binary_unary_ops_supported.snap new file mode 100644 index 00000000000..b0ad1ca6f4f --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/binary_unary_ops_supported.snap @@ -0,0 +1,25 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Boolean + let %6: Boolean + + bb0(): { + %2 = 10 // cost: 4 + %3 = 20 // cost: 4 + %4 = %2 + %3 // cost: 4 + %5 = %4 > 15 // cost: 4 + %6 = !%5 // cost: 4 + + return %6 + } +} + +=================== Traversals =================== + +Traversals: diff --git 
a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/diamond_must_analysis.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/diamond_must_analysis.snap new file mode 100644 index 00000000000..22352c1bc1d --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/diamond_must_analysis.snap @@ -0,0 +1,41 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (Integer,), %1: Entity) -> Boolean { + let %2: Boolean + let %3: (Integer,) + let %4: (Integer) -> Integer + let %5: Integer + let %6: Boolean + + bb0(): { + %2 = 1 // cost: 4 + %3 = %0 // cost: 4 + %4 = closure(({def@77} as FnPtr), %3) + + switchInt(%2) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + %5 = 42 // cost: 4 + + goto -> bb3(%5) + } + + bb2(): { + %5 = apply %4 1 + + goto -> bb3(%5) + } + + bb3(%5): { + %6 = %5 == 0 + + return %6 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_column.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_column.snap new file mode 100644 index 00000000000..1a61f1151e5 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_column.snap @@ -0,0 +1,18 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { + let %2: Boolean + + bb0(): { + %2 = %1.metadata.archived // cost: 4 + + return %2 + } +} + +=================== Traversals =================== + +Traversals: + %2: 4 diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_jsonb.snap 
b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_jsonb.snap new file mode 100644 index 00000000000..84f716750c5 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_jsonb.snap @@ -0,0 +1,18 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { + let %2: ? + + bb0(): { + %2 = %1.properties // cost: 4 + + return %2 + } +} + +=================== Traversals =================== + +Traversals: + %2: 4 diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_with_closure_type_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_with_closure_type_rejected.snap new file mode 100644 index 00000000000..06c6d7d008c --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_with_closure_type_rejected.snap @@ -0,0 +1,19 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (Integer, (Integer) -> Integer), %1: Entity) -> Boolean { + let %2: Integer + let %3: Boolean + + bb0(): { + %2 = %0.0 + %3 = %2 == 42 + + return %3 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_without_closure_accepted.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_without_closure_accepted.snap new file mode 100644 index 00000000000..c21e2e65bbd --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_without_closure_accepted.snap @@ -0,0 +1,19 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs 
+expression: output +--- +fn {graph::read::filter@4294967040}(%0: (Integer, Boolean), %1: Entity) -> Boolean { + let %2: Integer + let %3: Boolean + + bb0(): { + %2 = %0.0 // cost: 4 + %3 = %2 == 42 // cost: 4 + + return %3 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/graph_read_edge_unsupported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/graph_read_edge_unsupported.snap new file mode 100644 index 00000000000..f54367c9c82 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/graph_read_edge_unsupported.snap @@ -0,0 +1,30 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + let %6: Boolean + + bb0(): { + %2 = 10 // cost: 4 + %4 = 10 // cost: 4 + + graph read entities(%2) + |> collect -> bb1(_) + } + + bb1(%3): { + %5 = %3 + %4 + %6 = %5 > 0 + + return %6 + } +} + +=================== Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/input_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/input_supported.snap new file mode 100644 index 00000000000..afdd06acdc4 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/input_supported.snap @@ -0,0 +1,19 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { + let %2: Integer + let %3: Boolean + + bb0(): { + %2 = input LOAD threshold // cost: 4 + %3 = %2 > 100 // cost: 4 + + return %3 + } +} + +=================== 
Traversals =================== + +Traversals: diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/storage_statements_zero_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/storage_statements_zero_cost.snap new file mode 100644 index 00000000000..1eb59985945 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/storage_statements_zero_cost.snap @@ -0,0 +1,25 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/execution/statement_placement/tests.rs +expression: output +--- +fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Integer { + let %2: Integer + let %3: Integer + let %4: Integer + + bb0(): { + let %2 // cost: 0 + %2 = 10 // cost: 4 + let %3 // cost: 0 + %3 = 20 // cost: 4 + %4 = %2 + %3 // cost: 4 + drop %2 // cost: 0 + drop %3 // cost: 0 + + return %4 + } +} + +=================== Traversals =================== + +Traversals: