Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions examples/xegpu_matmul/matmul.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
# RUN: %PYTHON %s --dump-payload=xegpu-wg | FileCheck %s
# CHECK: module attributes {gpu.container_module} {

"""
Expand Down Expand Up @@ -315,7 +315,7 @@ def parse_cli():
help="Check the result of the matrix multiplication.",
)
parser.add_argument(
"--dump-kernel",
"--dump-payload",
type=str,
choices=[
"initial",
Expand All @@ -328,13 +328,18 @@ def parse_cli():
"xegpu-inst",
"final",
],
help="Dump kernel IR at different stages of lowering.",
help="Dump payload IR at different stages of lowering.",
)
parser.add_argument(
"--dump-schedule",
action="store_true",
help="Dump transform schedule.",
)
parser.add_argument(
"--non-det",
action="store_true",
help="Generate schedule with knob values left non-determined.",
)
args = parser.parse_args()

return args
Expand All @@ -344,21 +349,23 @@ def parse_cli():
args = parse_cli()

params = {
"auto_wg_d0": args.wg_tile[0],
"auto_wg_d1": args.wg_tile[1],
"auto_sg_d0": args.sg_tile[0],
"auto_sg_d1": args.sg_tile[1],
"auto_k": args.k_tile,
"auto_load_a_d0": args.load_tile_a[0],
"auto_load_a_d1": args.load_tile_a[1],
"auto_load_b_d0": args.load_tile_b[0],
"auto_load_b_d1": args.load_tile_b[1],
"auto_prefetch_a_d0": args.prefetch_tile_a[0],
"auto_prefetch_a_d1": args.prefetch_tile_a[1],
"auto_prefetch_b_d0": args.prefetch_tile_b[0],
"auto_prefetch_b_d1": args.prefetch_tile_b[1],
"auto_nb_prefetch": args.nb_prefetch,
"wg_d0": args.wg_tile[0],
"wg_d1": args.wg_tile[1],
"sg_d0": args.sg_tile[0],
"sg_d1": args.sg_tile[1],
"k_tile": args.k_tile,
"load_a_d0": args.load_tile_a[0],
"load_a_d1": args.load_tile_a[1],
"load_b_d0": args.load_tile_b[0],
"load_b_d1": args.load_tile_b[1],
"prefetch_a_d0": args.prefetch_tile_a[0],
"prefetch_a_d1": args.prefetch_tile_a[1],
"prefetch_b_d0": args.prefetch_tile_b[0],
"prefetch_b_d1": args.prefetch_tile_b[1],
"nb_prefetch": args.nb_prefetch,
}
if args.non_det:
params = {}

M, N, K = args.sizes
ab_type = "f16"
Expand All @@ -375,9 +382,14 @@ def parse_cli():
has_relu=args.relu,
)

if args.dump_kernel or args.dump_schedule:
if args.dump_schedule:
schedule_module = wload.schedule_module(
stop_at_stage=args.dump_payload, parameters=params
)
print(schedule_module)
elif args.dump_payload:
Comment on lines +385 to +390
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good if we could still dump both the payload and schedule simultaneously if so desired.

wload.lower_payload(
dump_payload=args.dump_kernel,
dump_payload=args.dump_payload,
dump_schedule=args.dump_schedule,
schedule_parameters=params,
)
Expand Down
202 changes: 165 additions & 37 deletions examples/xegpu_matmul/schedule.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import inspect
from typing import Optional, Annotated

from mlir import ir
from mlir.dialects.transform import loop
from mlir.dialects.transform import bufferization
from mlir.dialects.transform import xegpu
from mlir.dialects.bufferization import LayoutMapOption
from mlir.dialects import transform
from mlir.dialects.transform import structured
from lighthouse.utils.mlir import (
apply_registered_pass,
canonicalize,
match,
from mlir.dialects import transform, smt
from mlir.dialects.transform import (
structured,
tune as transform_tune,
smt as transform_smt,
)
from lighthouse.utils.mlir import apply_registered_pass, canonicalize, match
from lighthouse.tune.annotate import (
check_annotated_constraints,
NonDet,
ConstraintCollector,
)
from typing import Optional


class PipelineInterrupt(Exception):
Expand Down Expand Up @@ -76,7 +83,7 @@ def xegpu_matmul_transform_schedule(
has_bias=has_bias,
has_relu=has_relu,
stop_at_stage=stop_at_stage,
params=params,
**params,
)

mod = bundle_xegpu_to_binary(
Expand All @@ -89,45 +96,166 @@ def xegpu_matmul_transform_schedule(
transform.yield_()


@check_annotated_constraints
def bundle_xepu_matmul_schedule(
mod,
has_bias: bool = False,
has_relu: bool = False,
stop_at_stage: str = "",
params: Optional[dict] = None,
*,
wg_d0: Annotated[int, lambda _: 128 <= _ <= 256 and _ % 32 == 0] = NonDet,
wg_d1: Annotated[int, lambda _: 128 <= _ <= 256 and _ % 32 == 0] = NonDet,
sg_d0: Annotated[int, lambda _: 16 <= _ <= 32 and _ % 8 == 0] = NonDet,
sg_d1: Annotated[int, lambda _: 16 <= _ <= 32 and _ % 8 == 0] = NonDet,
k_tile: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
load_a_d0: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
load_a_d1: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
load_b_d0: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
load_b_d1: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
prefetch_a_d0: Annotated[int, lambda _: 4 <= _ <= 8] = NonDet,
prefetch_a_d1: Annotated[int, lambda _: 16 <= _ <= 32] = NonDet,
prefetch_b_d0: Annotated[int, lambda _: 4 <= _ <= 8] = NonDet,
prefetch_b_d1: Annotated[int, lambda _: 8 <= _ <= 16] = NonDet,
nb_prefetch: Annotated[int, lambda _: 1 <= _ <= 32] = NonDet,
**_kwargs: Optional[dict],
) -> ir.Module:
"""Schedule for lowering matmul-like payload to xegpu wg level."""
if params is None:
raise ValueError("Schedule parameters must be provided.")

# tunable parameters
wg_tile = [params["auto_wg_d0"], params["auto_wg_d1"]]
sg_tile = [params["auto_sg_d0"], params["auto_sg_d1"]]
k_tile = params["auto_k"]

load_tile_a = [params["auto_load_a_d0"], params["auto_load_a_d1"]]
load_tile_b = [params["auto_load_b_d0"], params["auto_load_b_d1"]]

prefetch_tile_a = [params["auto_prefetch_a_d0"], params["auto_prefetch_a_d1"]]
prefetch_tile_b = [params["auto_prefetch_b_d0"], params["auto_prefetch_b_d1"]]
nb_prefetch = params["auto_nb_prefetch"]

# derived parameters
sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]]
# number of threads collapsed to 1d layout
nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems
prefetch_layout_a = [
wg_tile[0] // prefetch_tile_a[0],
k_tile // prefetch_tile_a[1],
]
prefetch_layout_b = [
k_tile // prefetch_tile_b[0],
wg_tile[1] // prefetch_tile_b[1],

sig = inspect.signature(bundle_xepu_matmul_schedule)

any_param = transform.AnyParamType.get()

use_knobs = NonDet in [
wg_d0,
wg_d1,
prefetch_a_d0,
prefetch_a_d1,
prefetch_b_d0,
prefetch_b_d1,
k_tile,
load_a_d0,
load_a_d1,
load_b_d0,
load_b_d1,
prefetch_a_d0,
prefetch_a_d1,
prefetch_b_d0,
prefetch_b_d1,
nb_prefetch,
]

def as_const_or_as_knob(value, knob_name):
collector = ConstraintCollector()
sig.parameters[knob_name].annotation.__metadata__[0](collector)
if use_knobs:
return transform_tune.knob(
any_param,
name=knob_name,
options=collector.to_mlir(),
selected=value if value is not NonDet else None,
)
return value

wg_d0 = as_const_or_as_knob(wg_d0, "wg_d0")
wg_d1 = as_const_or_as_knob(wg_d1, "wg_d1")
wg_tile = [wg_d0, wg_d1]
sg_d0 = as_const_or_as_knob(sg_d0, "sg_d0")
sg_d1 = as_const_or_as_knob(sg_d1, "sg_d1")
sg_tile = [sg_d0, sg_d1]

smt_int = smt.IntType.get()
c0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)
c_nb_workitems = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), nb_workitems)

if use_knobs:
constraint1 = transform_smt.constrain_params(
(any_param, any_param, any_param),
(
wg_d0,
wg_d1,
sg_d0,
sg_d1,
),
[smt_int] * 4,
)
with ir.InsertionPoint(constraint1.body):
WGd0, WGd1, SGd0, SGd1 = constraint1.body.arguments
C0 = smt.int_constant(c0)
smt.assert_(smt.eq((smt.int_mod(WGd0, SGd0), C0)))
smt.assert_(smt.eq((smt.int_mod(WGd1, SGd1), C0)))
d0_step_smt = smt.int_div(WGd0, SGd0)
d1_step_smt = smt.int_div(WGd1, SGd1)
nb_threads_smt = smt.int_mul(
(d0_step_smt, d1_step_smt, smt.int_constant(c_nb_workitems))
)
smt.yield_((d0_step_smt, d1_step_smt, nb_threads_smt))
d0_step, d1_step, nb_threads = constraint1.results
sg_layout = [d0_step, d1_step]
else:
# derived parameters
sg_layout = [wg_d0 // sg_d0, wg_d1 // sg_d1]
# number of threads collapsed to 1d layout
nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems

prefetch_a_d0 = as_const_or_as_knob(prefetch_a_d0, "prefetch_a_d0")
prefetch_a_d1 = as_const_or_as_knob(prefetch_a_d1, "prefetch_a_d1")
prefetch_tile_a = [prefetch_a_d0, prefetch_a_d1]
prefetch_b_d0 = as_const_or_as_knob(prefetch_b_d0, "prefetch_b_d0")
prefetch_b_d1 = as_const_or_as_knob(prefetch_b_d1, "prefetch_b_d1")
prefetch_tile_b = [prefetch_b_d0, prefetch_b_d1]
k_tile = as_const_or_as_knob(k_tile, "k_tile")

if use_knobs:
constraint2 = transform_smt.constrain_params(
(any_param, any_param, any_param, any_param),
(
wg_d0,
wg_d1,
k_tile,
prefetch_a_d0,
prefetch_a_d1,
prefetch_b_d0,
prefetch_b_d1,
),
[smt_int] * 7,
)
with ir.InsertionPoint(constraint2.body):
WGd0, WGd1, K, PFAd0, PFAd1, PFBd0, PFBd1 = constraint2.body.arguments
C0 = smt.int_constant(c0)
smt.assert_(smt.eq((smt.int_mod(WGd0, PFAd0), C0)))
smt.assert_(smt.eq((smt.int_mod(K, PFAd1), C0)))
PFAd0_step = smt.int_div(WGd0, PFAd0)
PFAd1_step = smt.int_div(K, PFAd1)

smt.assert_(smt.eq((smt.int_mod(K, PFBd0), C0)))
smt.assert_(smt.eq((smt.int_mod(WGd1, PFBd1), C0)))
PFBd0_step = smt.int_div(K, PFBd0)
PFBd1_step = smt.int_div(WGd1, PFBd1)

smt.yield_((PFAd0_step, PFAd1_step, PFBd0_step, PFBd1_step))
prefetch_layout_a = constraint2.results[0:2]
prefetch_layout_b = constraint2.results[2:4]
else:
prefetch_layout_a = [
wg_d0 // prefetch_a_d0,
k_tile // prefetch_a_d1,
]
prefetch_layout_b = [
k_tile // prefetch_b_d0,
wg_d1 // prefetch_b_d1,
]

# matmul matrix shapes
sg_tile_a = [sg_tile[0], k_tile]
sg_tile_b = [k_tile, sg_tile[1]]
sg_tile_a = [sg_d0, k_tile]
sg_tile_b = [k_tile, sg_d1]

load_a_d0 = as_const_or_as_knob(load_a_d0, "load_a_d0")
load_a_d1 = as_const_or_as_knob(load_a_d1, "load_a_d1")
load_b_d0 = as_const_or_as_knob(load_b_d0, "load_b_d0")
load_b_d1 = as_const_or_as_knob(load_b_d1, "load_b_d1")

load_tile_a = [load_a_d0, load_a_d1]
load_tile_b = [load_b_d0, load_b_d1]

if stop_at_stage == "initial":
raise PipelineInterrupt()
Expand Down
21 changes: 21 additions & 0 deletions lighthouse/tune/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
__all__ = ["rewrite", "smt"]

import sys
import importlib


def __getattr__(name):
    """Enable lazy loading of submodules (PEP 562 module ``__getattr__``).

    Enables ``import lighthouse.tune as lh_tune; lh_tune.<submodule>`` with
    loading of (the submodule's heavy) dependencies only upon being needed.

    Args:
        name: Attribute requested on the ``lighthouse.tune`` package.

    Returns:
        The imported submodule, if ``name`` is listed in ``__all__``.

    Raises:
        AttributeError: If ``name`` is not a known submodule.
    """

    if name in __all__:
        # Import the submodule and cache it on the current module. That is,
        # upon the next access __getattr__ will not be called.
        submodule = importlib.import_module("lighthouse.tune." + name)
        lighthouse_tune_mod = sys.modules[__name__]
        setattr(lighthouse_tune_mod, name, submodule)
        return submodule
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
Loading