From 6cbd4779850f7c011e17bd8eb4fba9d480612993 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 17 Dec 2025 13:51:41 -0800 Subject: [PATCH 1/3] Infrastructure for auto-tuning --- examples/xegpu_matmul/matmul.py | 50 +++++--- examples/xegpu_matmul/schedule.py | 202 ++++++++++++++++++++++++------ lighthouse/tune/__init__.py | 21 ++++ lighthouse/tune/__main__.py | 71 +++++++++++ lighthouse/tune/annotate.py | 90 +++++++++++++ lighthouse/tune/rewrite.py | 120 ++++++++++++++++++ lighthouse/tune/smt/__init__.py | 21 ++++ lighthouse/tune/smt/z3.py | 187 +++++++++++++++++++++++++++ lighthouse/utils/types.py | 23 ++++ pyproject.toml | 3 + 10 files changed, 732 insertions(+), 56 deletions(-) create mode 100644 lighthouse/tune/__init__.py create mode 100644 lighthouse/tune/__main__.py create mode 100644 lighthouse/tune/annotate.py create mode 100644 lighthouse/tune/rewrite.py create mode 100644 lighthouse/tune/smt/__init__.py create mode 100644 lighthouse/tune/smt/z3.py create mode 100644 lighthouse/utils/types.py diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index d86478e..bf8a947 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# RUN: %PYTHON %s --dump-payload=xegpu-wg | FileCheck %s # CHECK: module attributes {gpu.container_module} { """ @@ -315,7 +315,7 @@ def parse_cli(): help="Check the result of the matrix multiplication.", ) parser.add_argument( - "--dump-kernel", + "--dump-payload", type=str, choices=[ "initial", @@ -328,13 +328,18 @@ def parse_cli(): "xegpu-inst", "final", ], - help="Dump kernel IR at different stages of lowering.", + help="Dump payload IR at different stages of lowering.", ) parser.add_argument( "--dump-schedule", action="store_true", help="Dump transform schedule.", ) + parser.add_argument( + "--non-det", + action="store_true", + help="Generate schedule with knob values left non-determined.", + ) args = 
parser.parse_args() return args @@ -344,21 +349,23 @@ def parse_cli(): args = parse_cli() params = { - "auto_wg_d0": args.wg_tile[0], - "auto_wg_d1": args.wg_tile[1], - "auto_sg_d0": args.sg_tile[0], - "auto_sg_d1": args.sg_tile[1], - "auto_k": args.k_tile, - "auto_load_a_d0": args.load_tile_a[0], - "auto_load_a_d1": args.load_tile_a[1], - "auto_load_b_d0": args.load_tile_b[0], - "auto_load_b_d1": args.load_tile_b[1], - "auto_prefetch_a_d0": args.prefetch_tile_a[0], - "auto_prefetch_a_d1": args.prefetch_tile_a[1], - "auto_prefetch_b_d0": args.prefetch_tile_b[0], - "auto_prefetch_b_d1": args.prefetch_tile_b[1], - "auto_nb_prefetch": args.nb_prefetch, + "wg_d0": args.wg_tile[0], + "wg_d1": args.wg_tile[1], + "sg_d0": args.sg_tile[0], + "sg_d1": args.sg_tile[1], + "k_tile": args.k_tile, + "load_a_d0": args.load_tile_a[0], + "load_a_d1": args.load_tile_a[1], + "load_b_d0": args.load_tile_b[0], + "load_b_d1": args.load_tile_b[1], + "prefetch_a_d0": args.prefetch_tile_a[0], + "prefetch_a_d1": args.prefetch_tile_a[1], + "prefetch_b_d0": args.prefetch_tile_b[0], + "prefetch_b_d1": args.prefetch_tile_b[1], + "nb_prefetch": args.nb_prefetch, } + if args.non_det: + params = {} M, N, K = args.sizes ab_type = "f16" @@ -375,9 +382,14 @@ def parse_cli(): has_relu=args.relu, ) - if args.dump_kernel or args.dump_schedule: + if args.dump_schedule: + schedule_module = wload.schedule_module( + stop_at_stage=args.dump_payload, parameters=params + ) + print(schedule_module) + elif args.dump_kernel: wload.lower_payload( - dump_payload=args.dump_kernel, + dump_payload=args.dump_payload, dump_schedule=args.dump_schedule, schedule_parameters=params, ) diff --git a/examples/xegpu_matmul/schedule.py b/examples/xegpu_matmul/schedule.py index 5a7133b..0ec55fb 100644 --- a/examples/xegpu_matmul/schedule.py +++ b/examples/xegpu_matmul/schedule.py @@ -1,16 +1,23 @@ +import inspect +from typing import Optional, Annotated + from mlir import ir from mlir.dialects.transform import loop from 
mlir.dialects.transform import bufferization from mlir.dialects.transform import xegpu from mlir.dialects.bufferization import LayoutMapOption -from mlir.dialects import transform -from mlir.dialects.transform import structured -from lighthouse.utils.mlir import ( - apply_registered_pass, - canonicalize, - match, +from mlir.dialects import transform, smt +from mlir.dialects.transform import ( + structured, + tune as transform_tune, + smt as transform_smt, +) +from lighthouse.utils.mlir import apply_registered_pass, canonicalize, match +from lighthouse.tune.annotate import ( + check_annotated_constraints, + NonDet, + ConstraintCollector, ) -from typing import Optional class PipelineInterrupt(Exception): @@ -76,7 +83,7 @@ def xegpu_matmul_transform_schedule( has_bias=has_bias, has_relu=has_relu, stop_at_stage=stop_at_stage, - params=params, + **params, ) mod = bundle_xegpu_to_binary( @@ -89,45 +96,166 @@ def xegpu_matmul_transform_schedule( transform.yield_() +@check_annotated_constraints def bundle_xepu_matmul_schedule( mod, has_bias: bool = False, has_relu: bool = False, stop_at_stage: str = "", - params: Optional[dict] = None, + *, + wg_d0: Annotated[int, lambda _: 128 <= _ <= 256 and _ % 32 == 0] = NonDet, + wg_d1: Annotated[int, lambda _: 128 <= _ <= 256 and _ % 32 == 0] = NonDet, + sg_d0: Annotated[int, lambda _: 16 <= _ <= 32 and _ % 8 == 0] = NonDet, + sg_d1: Annotated[int, lambda _: 16 <= _ <= 32 and _ % 8 == 0] = NonDet, + k_tile: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet, + load_a_d0: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet, + load_a_d1: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet, + load_b_d0: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet, + load_b_d1: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet, + prefetch_a_d0: Annotated[int, lambda _: 4 <= _ <= 8] = NonDet, + prefetch_a_d1: Annotated[int, lambda _: 16 <= _ <= 32] = NonDet, + prefetch_b_d0: Annotated[int, 
lambda _: 4 <= _ <= 8] = NonDet, + prefetch_b_d1: Annotated[int, lambda _: 8 <= _ <= 16] = NonDet, + nb_prefetch: Annotated[int, lambda _: 1 <= _ <= 32] = NonDet, + **_kwargs: Optional[dict], ) -> ir.Module: """Schedule for lowering matmul-like payload to xegpu wg level.""" - if params is None: - raise ValueError("Schedule parameters must be provided.") - - # tunable parameters - wg_tile = [params["auto_wg_d0"], params["auto_wg_d1"]] - sg_tile = [params["auto_sg_d0"], params["auto_sg_d1"]] - k_tile = params["auto_k"] - - load_tile_a = [params["auto_load_a_d0"], params["auto_load_a_d1"]] - load_tile_b = [params["auto_load_b_d0"], params["auto_load_b_d1"]] - - prefetch_tile_a = [params["auto_prefetch_a_d0"], params["auto_prefetch_a_d1"]] - prefetch_tile_b = [params["auto_prefetch_b_d0"], params["auto_prefetch_b_d1"]] - nb_prefetch = params["auto_nb_prefetch"] - - # derived parameters - sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] - # number of threads collapsed to 1d layout - nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems - prefetch_layout_a = [ - wg_tile[0] // prefetch_tile_a[0], - k_tile // prefetch_tile_a[1], - ] - prefetch_layout_b = [ - k_tile // prefetch_tile_b[0], - wg_tile[1] // prefetch_tile_b[1], + + sig = inspect.signature(bundle_xepu_matmul_schedule) + + any_param = transform.AnyParamType.get() + + use_knobs = NonDet in [ + wg_d0, + wg_d1, + sg_d0, + sg_d1, + prefetch_b_d0, + prefetch_b_d1, + k_tile, + load_a_d0, + load_a_d1, + load_b_d0, + load_b_d1, + prefetch_a_d0, + prefetch_a_d1, + prefetch_b_d0, + prefetch_b_d1, + nb_prefetch, ] + def as_const_or_as_knob(value, knob_name): + collector = ConstraintCollector() + sig.parameters[knob_name].annotation.__metadata__[0](collector) + if use_knobs: + return transform_tune.knob( + any_param, + name=knob_name, + options=collector.to_mlir(), + selected=value if value is not NonDet else None, + ) + return value + + wg_d0 = as_const_or_as_knob(wg_d0, "wg_d0") + wg_d1 = 
as_const_or_as_knob(wg_d1, "wg_d1") + wg_tile = [wg_d0, wg_d1] + sg_d0 = as_const_or_as_knob(sg_d0, "sg_d0") + sg_d1 = as_const_or_as_knob(sg_d1, "sg_d1") + sg_tile = [sg_d0, sg_d1] + + smt_int = smt.IntType.get() + c0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0) + c_nb_workitems = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), nb_workitems) + + if use_knobs: + constraint1 = transform_smt.constrain_params( + (any_param, any_param, any_param), + ( + wg_d0, + wg_d1, + sg_d0, + sg_d1, + ), + [smt_int] * 4, + ) + with ir.InsertionPoint(constraint1.body): + WGd0, WGd1, SGd0, SGd1 = constraint1.body.arguments + C0 = smt.int_constant(c0) + smt.assert_(smt.eq((smt.int_mod(WGd0, SGd0), C0))) + smt.assert_(smt.eq((smt.int_mod(WGd1, SGd1), C0))) + d0_step_smt = smt.int_div(WGd0, SGd0) + d1_step_smt = smt.int_div(WGd1, SGd1) + nb_threads_smt = smt.int_mul( + (d0_step_smt, d1_step_smt, smt.int_constant(c_nb_workitems)) + ) + smt.yield_((d0_step_smt, d1_step_smt, nb_threads_smt)) + d0_step, d1_step, nb_threads = constraint1.results + sg_layout = [d0_step, d1_step] + else: + # derived parameters + sg_layout = [wg_d0 // sg_d0, wg_d1 // sg_d1] + # number of threads collapsed to 1d layout + nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems + + prefetch_a_d0 = as_const_or_as_knob(prefetch_a_d0, "prefetch_a_d0") + prefetch_a_d1 = as_const_or_as_knob(prefetch_a_d1, "prefetch_a_d1") + prefetch_tile_a = [prefetch_a_d0, prefetch_a_d1] + prefetch_b_d0 = as_const_or_as_knob(prefetch_b_d0, "prefetch_b_d0") + prefetch_b_d1 = as_const_or_as_knob(prefetch_b_d1, "prefetch_b_d1") + prefetch_tile_b = [prefetch_b_d0, prefetch_b_d1] + k_tile = as_const_or_as_knob(k_tile, "k_tile") + + if use_knobs: + constraint2 = transform_smt.constrain_params( + (any_param, any_param, any_param, any_param), + ( + wg_d0, + wg_d1, + k_tile, + prefetch_a_d0, + prefetch_a_d1, + prefetch_b_d0, + prefetch_b_d1, + ), + [smt_int] * 7, + ) + with ir.InsertionPoint(constraint2.body): + WGd0, WGd1, K, 
PFAd0, PFAd1, PFBd0, PFBd1 = constraint2.body.arguments + C0 = smt.int_constant(c0) + smt.assert_(smt.eq((smt.int_mod(WGd0, PFAd0), C0))) + smt.assert_(smt.eq((smt.int_mod(K, PFAd1), C0))) + PFAd0_step = smt.int_div(WGd0, PFAd0) + PFAd1_step = smt.int_div(K, PFAd1) + + smt.assert_(smt.eq((smt.int_mod(K, PFBd0), C0))) + smt.assert_(smt.eq((smt.int_mod(WGd1, PFBd1), C0))) + PFBd0_step = smt.int_div(K, PFBd0) + PFBd1_step = smt.int_div(WGd1, PFBd1) + + smt.yield_((PFAd0_step, PFAd1_step, PFBd0_step, PFBd1_step)) + prefetch_layout_a = constraint2.results[0:2] + prefetch_layout_b = constraint2.results[2:4] + else: + prefetch_layout_a = [ + wg_d0 // prefetch_a_d0, + k_tile // prefetch_a_d1, + ] + prefetch_layout_b = [ + k_tile // prefetch_b_d0, + wg_d1 // prefetch_b_d1, + ] + # matmul matrix shapes - sg_tile_a = [sg_tile[0], k_tile] - sg_tile_b = [k_tile, sg_tile[1]] + sg_tile_a = [sg_d0, k_tile] + sg_tile_b = [k_tile, sg_d1] + + load_a_d0 = as_const_or_as_knob(load_a_d0, "load_a_d0") + load_a_d1 = as_const_or_as_knob(load_a_d1, "load_a_d1") + load_b_d0 = as_const_or_as_knob(load_b_d0, "load_b_d0") + load_b_d1 = as_const_or_as_knob(load_b_d1, "load_b_d1") + + load_tile_a = [load_a_d0, load_a_d1] + load_tile_b = [load_b_d0, load_b_d1] if stop_at_stage == "initial": raise PipelineInterrupt() diff --git a/lighthouse/tune/__init__.py b/lighthouse/tune/__init__.py new file mode 100644 index 0000000..467f825 --- /dev/null +++ b/lighthouse/tune/__init__.py @@ -0,0 +1,21 @@ +__all__ = ["smt", "rewrite"] + +import sys +import importlib + + +def __getattr__(name): + """Enable lazy loading of submodules. + + Enables `import lighthouse.tune as lh_tune; lh_tune.` with + loading of (the submodule's heavy) depenendencies only upon being needed. + """ + + if name in __all__: + # Import the submodule and cache it on the current module. That is, + # upon the next access __getattr__ will not be called. + submodule = importlib.import_module("lighthouse.tune." 
+ name) + lighthouse_tune_mod = sys.modules[__name__] + setattr(lighthouse_tune_mod, name, submodule) + return submodule + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/lighthouse/tune/__main__.py b/lighthouse/tune/__main__.py new file mode 100644 index 0000000..f8f2110 --- /dev/null +++ b/lighthouse/tune/__main__.py @@ -0,0 +1,71 @@ +import sys +import argparse +from pprint import pprint +from typing import Mapping + +import z3 + +from mlir import ir +import lighthouse.tune as lh_tune +from lighthouse.utils.types import LazyChainMap + +HEADER = "//" * 40 + "\n// {}\n" + "//" * 40 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("file", type=str, help="Path to the MLIR file to process") + parser.add_argument( + "-n", type=int, help="Number of determinized schedules to find", default=1 + ) + parser.add_argument( + "--print-smtlib", action="store_true", help="Print the constraints in SMT-LIB format" + ) + parser.add_argument( + "--print-model", action="store_true", help="Print the model from the SMT solver" + ) + parser.add_argument( + "--print-knobs-set", action="store_true", help="Print the schedule with knobs set" + ) + args = parser.parse_args() + + file = sys.stdin if args.file == "-" else open(args.file, "r") + with ir.Context() as ctx, ir.Location.unknown(): + module = ir.Module.parse(file.read()) + + z3_constraints, values_to_z3_vars = ( + lh_tune.smt.z3.transform_tune_and_smt_ops_to_z3_constraints( + module.operation + ) + ) + + solver = z3.Solver() + solver.add(z3_constraints) + + if args.print_smtlib: + print(HEADER.format("SMT-LIB constraints")) + print(solver.sexpr()) + + all_models = lh_tune.smt.z3.all_smt(solver, values_to_z3_vars.values()) + for i in range(args.n): + model = next(all_models) + if args.print_model: + print(HEADER.format(f"SMT Model #{i+1}")) + pprint(model) + + env: Mapping[ir.Value | ir.Operation, ir.Attribute] = LazyChainMap( + values_to_z3_vars, lambda 
var: lh_tune.smt.z3.model_to_mlir(model[var]) + ) + + mod_op = lh_tune.rewrite.set_selected(module.operation, env) + + mod_op, undo = lh_tune.rewrite.constraint_results_to_constants(mod_op, env) + + if args.print_knobs_set: + print(HEADER.format(f"Schedule #{i+1} with knobs set")) + print(mod_op) + + print(HEADER.format(f"Determinized schedule #{i+1}")) + print(lh_tune.rewrite.nondet_to_det(mod_op.clone())) + + undo() # Undo the introduction of constants for the results of constraints. diff --git a/lighthouse/tune/annotate.py b/lighthouse/tune/annotate.py new file mode 100644 index 0000000..39241e8 --- /dev/null +++ b/lighthouse/tune/annotate.py @@ -0,0 +1,90 @@ +import inspect +from dataclasses import dataclass, field +from typing import Any, Callable, get_args +from functools import wraps + +from mlir import ir + +NonDet = object() # Sentinel value for non-determinized parameters. + + +def check_annotated_constraints(f: Callable): + """Wrapper that runs __metadata__ constraints from Annotated types on function arguments.""" + + @wraps(f) + def wrapper(*args, **kwargs): + sig = inspect.signature(f) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + for name, value in bound_args.arguments.items(): + if value is NonDet: + continue + + param = sig.parameters[name] + + if param.kind != param.KEYWORD_ONLY or not hasattr( + param.annotation, "__metadata__" + ): + continue + + wrapped_type, constraint = get_args(param.annotation) + + if not isinstance(value, wrapped_type): + raise TypeError( + f"Argument {name} must be of type {param.annotation}, got {type(value)}" + ) + + if not constraint(value): + raise ValueError( + f"Constraint {constraint} not satisfied for argument {name} with value {value}" + ) + + return f(*args, **kwargs) + + return wrapper + + +@dataclass +class ConstraintCollector: + """Used on annotated constraint functions to reify the constraints as data.""" + + children: list[Any] = field(default_factory=list) + lb: int | None = 
None + ub: int | None = None + + def __mod__(self, other): + self.children.append(mod := ModConstraintCollector(modulus=other)) + return mod + + def __le__(self, other): + self.ub = other + return self + + def __ge__(self, other): + self.lb = other + return self + + def to_mlir(self) -> ir.Attribute: + dict_attrs = {} + i64 = ir.IntegerType.get_signless(64) + if self.lb is not None: + dict_attrs["lb"] = ir.IntegerAttr.get(i64, self.lb) + if self.ub is not None: + dict_attrs["ub"] = ir.IntegerAttr.get(i64, self.ub) + if self.children: + assert len(self.children) == 1 and isinstance( + self.children[0], ModConstraintCollector + ) + dict_attrs["step"] = ir.IntegerAttr.get(i64, self.children[0].modulus) + return ir.DictAttr.get(dict_attrs) + + +@dataclass +class ModConstraintCollector: + modulus: int + remainder: int | None = None + + def __eq__(self, other): + assert other == 0, "Only equality with zero is currently supported" + return self diff --git a/lighthouse/tune/rewrite.py b/lighthouse/tune/rewrite.py new file mode 100644 index 0000000..6401f38 --- /dev/null +++ b/lighthouse/tune/rewrite.py @@ -0,0 +1,120 @@ +from collections import OrderedDict + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import smt as transform_smt, tune as transform_tune + + +def set_selected(op: ir.Operation, env: dict[ir.Value | ir.Operation, object]): + def recurse(op: ir.Operation): + for region in op.regions: + for block in region.blocks: + for child in block: + set_selected(child, env) + + match type(op): + case transform_tune.KnobOp: + op.attributes["selected"] = env[op.result] + case transform_tune.AlternativesOp: + op.attributes["selected_region"] = env[op] + recurse(op) + case _: + recurse(op) + return op + + +# This is a hack >;( +def constraint_results_to_constants( + op: ir.Operation | ir.Module, + env: dict[ir.Value | ir.Operation, object], + undo_actions=None, +): + undo_actions = undo_actions if undo_actions is not None else [] + + 
def undo(): + for action in reversed(undo_actions): + action() + + match type(op): + case transform_smt.ConstrainParamsOp: + with ir.InsertionPoint.after(op): + orig_results_and_uses = OrderedDict( + (res, list((use.owner, use.operand_number) for use in res.uses)) + for res in op.results + ) + + for result in op.results: + val = transform.param_constant(result.type, env[result]) + + for use in result.uses: + use.owner.operands[use.operand_number] = val + + def undo_rewrite(): + for orig_res, orig_uses in orig_results_and_uses.items(): + for orig_owner, orig_operand_number in orig_uses: + param = orig_owner.operands[orig_operand_number].owner + assert isinstance(param, transform.ParamConstantOp) + orig_owner.operands[orig_operand_number] = orig_res + param.erase() + + undo_actions.append(undo_rewrite) + case _: + for region in op.regions: + for block in region.blocks: + for child in block: + constraint_results_to_constants(child, env, undo_actions) + + return op, undo + + +def nondet_to_det(op: ir.Operation, env: dict[ir.Value, ir.Value] = None): + env = env if env is not None else {} # TODO: nested scopes + + i64 = ir.IntegerType.get_signless(64) + transform_param_i64 = transform.ParamType.get(i64) + + match type(op): + case transform_tune.KnobOp: + assert "selected" in op.attributes + with ir.InsertionPoint.after(op): + subst = transform.param_constant( + transform_param_i64, op.attributes["selected"] + ) + + for use in op.result.uses: + use.owner.operands[use.operand_number] = subst + + op.erase() + + case transform_tune.AlternativesOp: + assert "selected_region" in op.attributes + region_idx = op.attributes["selected_region"].value + with ir.InsertionPoint.after(op): + for child in op.regions[region_idx].blocks[0]: + new_yield = cloned_child = child.clone() + for result, new_result in zip(child.results, cloned_child.results): + env[result] = new_result + for idx, operand in enumerate(cloned_child.operands): + if operand in env: + cloned_child.operands[idx] = 
env[operand] + nondet_to_det(cloned_child, env) + for yield_operand, result in zip(new_yield.operands, op.results): + for res_use in result.uses: + res_use.owner.operands[res_use.operand_number] = yield_operand + new_yield.erase() + + op.erase() + + case transform_smt.ConstrainParamsOp: + for res in op.results: + assert next(res.uses, None) is None + + op.erase() + + case _: + for region in op.regions: + for block in region.blocks: + for child in block: + nondet_to_det(child, env) + + return op diff --git a/lighthouse/tune/smt/__init__.py b/lighthouse/tune/smt/__init__.py new file mode 100644 index 0000000..75cc276 --- /dev/null +++ b/lighthouse/tune/smt/__init__.py @@ -0,0 +1,21 @@ +__all__ = ["z3"] + +import sys +import importlib + + +def __getattr__(name): + """Enable lazy loading of submodules. + + Enables `import lighthouse.tune as lh_tune; lh_tune.smt.` with + loading of (the submodule's heavy) dependencies only upon being needed. + """ + + if name in __all__: + # Import the submodule and cache it on the current module. That is, + # upon the next access __getattr__ will not be called. + submodule = importlib.import_module("lighthouse.tune.smt." 
+ name) + lighthouse_mod = sys.modules[__name__] + setattr(lighthouse_mod, name, submodule) + return submodule + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/lighthouse/tune/smt/z3.py b/lighthouse/tune/smt/z3.py new file mode 100644 index 0000000..aa5113f --- /dev/null +++ b/lighthouse/tune/smt/z3.py @@ -0,0 +1,187 @@ +import operator + +from functools import reduce + +from mlir import ir +from mlir.dialects import transform, smt +from mlir.dialects.transform import smt as transform_smt, tune as transform_tune + +import z3 + + +# From: http://theory.stanford.edu/%7Enikolaj/programmingz3.html#sec-blocking-evaluations +def all_smt(s, initial_terms): + def block_term(s, m, t): + s.add(t != m.eval(t)) + + def fix_term(s, m, t): + s.add(t == m.eval(t)) + + def all_smt_rec(terms): + if z3.sat == s.check(): + m = s.model() + yield m + for i in range(len(terms)): + s.push() + block_term(s, m, terms[i]) + for j in range(i): + fix_term(s, m, terms[j]) + yield from all_smt_rec(terms[i:]) + s.pop() + + yield from all_smt_rec(list(initial_terms)) + + +def model_to_mlir(x): + return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), x.py_value()) + + +name_counter = 0 + + +def fresh_name(prefix="fresh"): + global name_counter + name_counter += 1 + return f"{prefix}{name_counter}" + + +def transform_tune_and_smt_ops_to_z3_constraints( + op: ir.Operation | ir.OpView, env=None, path=None, constraints=None +): + env = env if env is not None else {} # TODO: nested scopes + path = path if path is not None else [] + constraints = constraints if constraints is not None else [] + + match type(op): + case transform.ParamConstantOp: + name = fresh_name("cst") + var = env[op.result] = z3.Int(name) + C = op.value.value + constraints += [z3.Implies(z3.And(*path), var == C)] + + case transform.MatchParamCmpIOp: + lvar = env[op.operands[0]] + rvar = env[op.operands[1]] + predicate_to_operator = { + transform.MatchCmpIPredicate.eq: operator.eq, + 
transform.MatchCmpIPredicate.le: operator.le, + transform.MatchCmpIPredicate.lt: operator.lt, + transform.MatchCmpIPredicate.ge: operator.ge, + transform.MatchCmpIPredicate.gt: operator.gt, + transform.MatchCmpIPredicate.ne: operator.ne, + } + constraints += [ + z3.Implies( + z3.And(*path), + predicate_to_operator[ + transform.MatchCmpIPredicate(op.predicate.value) + ](lvar, rvar), + ) + ] + + case transform_tune.KnobOp: + var = env[op.result] = z3.Int(op.name.value) + if isinstance(op.options, ir.ArrayAttr): + constraints += [ + z3.Implies( + z3.And(*path), z3.Or(*[var == opt.value for opt in op.options]) + ) + ] + elif isinstance(op.options, ir.DictAttr): + assert "lb" in op.options and "ub" in op.options + atoms = [op.options["lb"].value <= var, var <= op.options["ub"].value] + if "step" in op.options: + atoms += [var % op.options["step"].value == 0] + + constraints += [z3.Implies(z3.And(*path), z3.And(atoms))] + else: + assert False, "Unknown options attribute type" + + case transform_tune.AlternativesOp: + var = env[op] = z3.Int(op.name.value) + constraints += [ + z3.Implies(z3.And(path), z3.And(0 <= var, var < len(op.regions))) + ] + for idx, region in enumerate(op.regions): + for child in region.blocks[0]: + transform_tune_and_smt_ops_to_z3_constraints( + child, env, path + [var == idx], constraints + ) + for yield_operand, result in zip( + region.blocks[0].operations[-1].operands, op.results + ): + if isinstance(yield_operand.type, transform.ParamType): + env[result] = z3.Int(fresh_name("alt_res")) + constraints += [ + z3.Implies( + z3.And(*(path + [var == idx])), + env[result] == env[yield_operand], + ) + ] + + case transform_smt.ConstrainParamsOp: + mapping_constraints = [] + for operand, block_arg in zip(op.operands, op.body.arguments): + var = env[block_arg] = z3.Int(fresh_name("cp_bbarg")) + mapping_constraints += [var == env[operand]] + constraints += [z3.Implies(z3.And(*path), z3.And(*mapping_constraints))] + for idx in range(len(op.body.operations) 
- 1): + transform_tune_and_smt_ops_to_z3_constraints( + op.body.operations[idx], env, path, constraints + ) + assert isinstance(op.body.operations[-1], smt.YieldOp) + mapping_constraints = [] + for result, yield_arg in zip(op.results, op.body.operations[-1].operands): + var = env[result] = z3.Int(fresh_name("cp_bbres")) + mapping_constraints += [var == env[yield_arg]] + constraints += [z3.Implies(z3.And(*path), z3.And(*mapping_constraints))] + + case smt.IntAddOp: + var = env[op.result] = z3.Int(fresh_name("add")) + constraints += [ + z3.Implies( + z3.And(*path), var == sum(env[value] for value in op.operands) + ) + ] + + case smt.IntMulOp: + var = env[op.result] = z3.Int(fresh_name("mul")) + constraints += [ + z3.Implies( + z3.And(*path), + var == reduce(operator.mul, (env[value] for value in op.operands)), + ) + ] + + case smt.IntConstantOp: + var = env[op.result] = z3.Int(fresh_name("cst")) + constraints += [z3.Implies(z3.And(*path), var == op.value.value)] + + case smt.EqOp: + assert len(op.operands) == 2 + lhs, rhs = env[op.operands[0]], env[op.operands[1]] + var = env[op.result] = z3.Bool(fresh_name("eq")) + constraints += [z3.Implies(z3.And(*path), var == (lhs == rhs))] + + case smt.IntModOp: + lhs, rhs = env[op.lhs], env[op.rhs] + var = env[op.result] = z3.Int(fresh_name("int.mod")) + constraints += [z3.Implies(z3.And(*path), var == (lhs % rhs))] + + case smt.IntDivOp: + lhs, rhs = env[op.lhs], env[op.rhs] + var = env[op.result] = z3.Int(fresh_name("int.div")) + constraints += [z3.Implies(z3.And(*path), var == (lhs / rhs))] + + case smt.AssertOp: + constraints += [z3.Implies(z3.And(*path), env[op.input])] + + case _: + for region in op.regions: + for block in region.blocks: + for child in block: + transform_tune_and_smt_ops_to_z3_constraints( + child, env, path, constraints + ) + + return [z3.simplify(c) for c in constraints], env diff --git a/lighthouse/utils/types.py b/lighthouse/utils/types.py new file mode 100644 index 0000000..b0d87d9 --- /dev/null +++ 
b/lighthouse/utils/types.py @@ -0,0 +1,23 @@ +from typing import TypeVar, Generic, Callable + +from collections.abc import Mapping + +K = TypeVar("K") +V = TypeVar("V") +W = TypeVar("V") + +class LazyChainMap(Mapping, Generic[K, V, W]): + def __init__(self, data: dict[K, V], func: Callable[V, W]): + self._data = data + self._func = func + + def __getitem__(self, key): + # Access the underlying data and apply the transformation + value = self._data[key] + return self._func(value) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) diff --git a/pyproject.toml b/pyproject.toml index dece5b8..2a0ed74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ ingress_torch_xpu = [ "pytorch_triton_xpu", # Transitive dependency listed explicitly so that we can state which package repository it is supposed to come from "lighthouse[ingress_torch_mlir]" ] +tune_smt_z3 = [ + "z3-solver" +] [tool.uv] # Declare that the following "targets" are mutually exclusive of one another From d14486938d11bdc16780121981375f192c31970a Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 17 Dec 2025 14:08:05 -0800 Subject: [PATCH 2/3] Linted --- lighthouse/tune/__init__.py | 2 +- lighthouse/tune/__main__.py | 16 ++++++++++------ lighthouse/utils/types.py | 1 + 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/lighthouse/tune/__init__.py b/lighthouse/tune/__init__.py index 467f825..c5b3bf9 100644 --- a/lighthouse/tune/__init__.py +++ b/lighthouse/tune/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["smt", "rewrite"] +__all__ = ["rewrite", "smt"] import sys import importlib diff --git a/lighthouse/tune/__main__.py b/lighthouse/tune/__main__.py index f8f2110..b9455fd 100644 --- a/lighthouse/tune/__main__.py +++ b/lighthouse/tune/__main__.py @@ -19,13 +19,17 @@ "-n", type=int, help="Number of determinized schedules to find", default=1 ) parser.add_argument( - "--print-smtlib", action="store_true", help="Print the constraints in SMT-LIB 
format" + "--print-smtlib", + action="store_true", + help="Print the constraints in SMT-LIB format", ) parser.add_argument( "--print-model", action="store_true", help="Print the model from the SMT solver" ) parser.add_argument( - "--print-knobs-set", action="store_true", help="Print the schedule with knobs set" + "--print-knobs-set", + action="store_true", + help="Print the schedule with knobs set", ) args = parser.parse_args() @@ -50,7 +54,7 @@ for i in range(args.n): model = next(all_models) if args.print_model: - print(HEADER.format(f"SMT Model #{i+1}")) + print(HEADER.format(f"SMT Model #{i + 1}")) pprint(model) env: Mapping[ir.Value | ir.Operation, ir.Attribute] = LazyChainMap( @@ -62,10 +66,10 @@ mod_op, undo = lh_tune.rewrite.constraint_results_to_constants(mod_op, env) if args.print_knobs_set: - print(HEADER.format(f"Schedule #{i+1} with knobs set")) + print(HEADER.format(f"Schedule #{i + 1} with knobs set")) print(mod_op) - print(HEADER.format(f"Determinized schedule #{i+1}")) + print(HEADER.format(f"Determinized schedule #{i + 1}")) print(lh_tune.rewrite.nondet_to_det(mod_op.clone())) - undo() # Undo the introduction of constants for the results of constraints. + undo() # Undo the introduction of constants for the results of constraints. 
diff --git a/lighthouse/utils/types.py b/lighthouse/utils/types.py index b0d87d9..8a92586 100644 --- a/lighthouse/utils/types.py +++ b/lighthouse/utils/types.py @@ -6,6 +6,7 @@ V = TypeVar("V") W = TypeVar("V") + class LazyChainMap(Mapping, Generic[K, V, W]): def __init__(self, data: dict[K, V], func: Callable[V, W]): self._data = data From 0aad2dec608c79defcbeb3c67019f4c75b83ed7f Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 17 Dec 2025 14:17:49 -0800 Subject: [PATCH 3/3] dump_kernel -> dump_payload --- examples/xegpu_matmul/matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/xegpu_matmul/matmul.py b/examples/xegpu_matmul/matmul.py index bf8a947..951b4fa 100644 --- a/examples/xegpu_matmul/matmul.py +++ b/examples/xegpu_matmul/matmul.py @@ -387,7 +387,7 @@ def parse_cli(): stop_at_stage=args.dump_payload, parameters=params ) print(schedule_module) - elif args.dump_kernel: + elif args.dump_payload: wload.lower_payload( dump_payload=args.dump_payload, dump_schedule=args.dump_schedule,