Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion KLR/Core/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ partial def operatorBasicTensors : Operator → List TensorRef
| .nonzeroWithCount n => [n.dst, n.src]
| .devicePrint t => [t.src]
| .exponential e => [e.dst, e.src]
| .activate2 a => [a.dst, a.src]

partial def operatorAdditionalTensors : Operator → List TensorName
| .ncActivate d => (tensors d.scale) ++ (tensors d.bias) ++ (tensors d.reduceRes)
Expand All @@ -213,7 +214,7 @@ partial def operatorAdditionalTensors : Operator → List TensorName
| .tensorScalarCumulative t => (tensors t.imm0) ++ (tensors t.imm1)
| .ncNGather _ => []
| .nonzeroWithCount _ => []
| .exponential e => (tensors e.maxValue) ++ (tensors e.ReduceInit)
| .exponential e => (tensors e.maxValue) ++ (tensors e.reduceRes) ++ (tensors e.reduceInit)
| _ => []

instance : Tensors Operator where
Expand Down
38 changes: 35 additions & 3 deletions KLR/Core/Operators.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1219,19 +1219,49 @@ structure Exponential where
dst : TensorRef
src : TensorRef
maxValue : Operand
reduceRes : Option TensorRef
reducecmd : AccumCmd
ReduceInit : Operand
reduceInit : Operand
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

instance : MapTensorRefs Exponential where
mapM ft fo op := do pure { op with
dst := ← ft op.dst,
src := ← ft op.src,
maxValue := ← fo op.maxValue,
ReduceInit := ← fo op.ReduceInit
reduceRes := ← op.reduceRes.mapM ft,
reduceInit := ← fo op.reduceInit
}

@[serde tag = 217]
structure Activate2 where
dst : TensorRef
src : TensorRef
op0 : AluOp
op1 : AluOp
imm0 : Operand
imm1 : Operand
activationFunc : ActivationFunc
reluParam : Operand
reduceOp : AluOp
reduceRes : Option TensorRef
reduceCmd : AccumCmd
reverse0 : Bool
reverse1 : Bool
dtype : Option Dtype
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

instance : MapTensorRefs Activate2 where
mapM ft fo op := do pure { op with
dst := ← ft op.dst,
src := ← ft op.src,
imm0 := ← fo op.imm0,
imm1 := ← fo op.imm1,
reluParam := ← fo op.reluParam,
reduceRes := ← op.reduceRes.mapM ft
}

@[serde tag = 218]
inductive Operator where
| activate (op : Activate)
| ncActivate (op : NcActivate)
Expand Down Expand Up @@ -1308,9 +1338,10 @@ inductive Operator where
| nonzeroWithCount (op: NonzeroWithCount)
| devicePrint (op: DevicePrint)
| exponential(op: Exponential)
| activate2 (op: Activate2)
deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp

@[serde tag = 218]
@[serde tag = 219]
inductive TGROperator where
| activate (op : Activate)
| affineSelect (op : AffineSelect)
Expand Down Expand Up @@ -1419,3 +1450,4 @@ instance : MapTensorRefs Operator where
| .nonzeroWithCount op => return .nonzeroWithCount (← MapTensorRefs.mapM ft fo op)
| .devicePrint op => return .devicePrint (← MapTensorRefs.mapM ft fo op)
| .exponential op => return .exponential (← MapTensorRefs.mapM ft fo op)
| .activate2 op => return .activate2 (← MapTensorRefs.mapM ft fo op)
1 change: 1 addition & 0 deletions KLR/Extract/Extract/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def klrAST: MetaM (List LeanType) := do
`KLR.Core.PrintOutputBuffer,
`KLR.Core.DevicePrint,
`KLR.Core.Exponential,
`KLR.Core.Activate2,
`KLR.Core.Operator,
-- Core.Basic
`KLR.Core.Stmt,
Expand Down
46 changes: 45 additions & 1 deletion KLR/Trace/ISA.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,7 @@ nki builtin.isa.exponential
(dst : Access)
(src : Access)
(max_value : Sum Immediate Access := .inl $ .float 0.0)
(reduce_res : Option Access := none)
(reduce_cmd : AccumCmd := .Idle)
(reduce_init : Sum Immediate Access := .inl $ .float 0.0)
(mask : Option Immediate := none)
Expand All @@ -1246,9 +1247,52 @@ nki builtin.isa.exponential
maxValue := match max_value with
| .inl imm => .imm imm
| .inr t => .tile $ .abstract t,
reduceRes := reduce_res.map .abstract,
reducecmd := reduce_cmd,
ReduceInit := match reduce_init with
reduceInit := match reduce_init with
| .inl imm => .imm imm
| .inr t => .tile $ .abstract t
}) name
return .none

nki builtin.isa.activate2
(dst : Access)
(src : Access)
(op0 : AluOp)
(op1 : AluOp)
(imm0 : Sum Immediate Access)
(imm1 : Sum Immediate Access)
(op : ActivationFunc)
-- kwargs
(relu_param : Sum Immediate Access := .inl $ .float 0.0)
(reduce_op : AluOp := .bypass)
(reduce_res : Option Access := none)
(reduce_cmd : AccumCmd := .Idle)
(reverse0 : Bool := false)
(reverse1 : Bool := false)
(mask : Option Immediate := none)
(name : Option String := none) := do
if mask.isSome then throw maskNotSupported
Trace.add_stmt $ .oper (.activate2 {
dst := .abstract dst,
src := .abstract src,
op0 := op0,
op1 := op1,
imm0 := match imm0 with
| .inl imm => .imm imm
| .inr t => .tile $ .abstract t,
imm1 := match imm1 with
| .inl imm => .imm imm
| .inr t => .tile $ .abstract t,
activationFunc := op,
reluParam := match relu_param with
| .inl imm => .imm imm
| .inr t => .tile $ .abstract t,
reduceOp := reduce_op,
reduceRes := reduce_res.map .abstract,
reduceCmd := reduce_cmd,
reverse0 := reverse0,
reverse1 := reverse1,
dtype := dst.tensor.dtype
}) name
return .none
5 changes: 4 additions & 1 deletion interop/klr/NKI.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,9 @@ PrintOutputBuffer =

DevicePrint = (TensorRef src, String printPrefix, PrintOutputBuffer buffer)

Exponential = (TensorRef dst, TensorRef src, Operand maxValue, AccumCmd reducecmd, Operand ReduceInit)
Exponential = (TensorRef dst, TensorRef src, Operand maxValue, TensorRef? reduceRes, AccumCmd reducecmd, Operand reduceInit)

Activate2 = (TensorRef dst, TensorRef src, AluOp op0, AluOp op1, Operand imm0, Operand imm1, ActivationFunc activationFunc, Operand reluParam, AluOp reduceOp, TensorRef? reduceRes, AccumCmd reduceCmd, Bool reverse0, Bool reverse1, Dtype? dtype)
Operator =
| activate(Activate op)
| ncActivate(NcActivate op)
Expand Down Expand Up @@ -440,6 +442,7 @@ Operator =
| nonzeroWithCount(NonzeroWithCount op)
| devicePrint(DevicePrint op)
| exponential(Exponential op)
| activate2(Activate2 op)

Stmt =
| oper(Operator op, String? name, Pos pos)
Expand Down
26 changes: 25 additions & 1 deletion interop/klr/klir_ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,8 +1065,26 @@ struct Exponential final {
Ptr<TensorRef> dst;
Ptr<TensorRef> src;
Ptr<Operand> maxValue;
Option<Ptr<TensorRef>> reduceRes;
AccumCmd reducecmd;
Ptr<Operand> ReduceInit;
Ptr<Operand> reduceInit;
};

struct Activate2 final {
Ptr<TensorRef> dst;
Ptr<TensorRef> src;
AluOp op0;
AluOp op1;
Ptr<Operand> imm0;
Ptr<Operand> imm1;
ActivationFunc activationFunc;
Ptr<Operand> reluParam;
AluOp reduceOp;
Option<Ptr<TensorRef>> reduceRes;
AccumCmd reduceCmd;
Bool reverse0;
Bool reverse1;
Option<Dtype> dtype;
};

struct Operator {
Expand Down Expand Up @@ -1146,6 +1164,7 @@ struct Operator {
nonzeroWithCount,
devicePrint,
exponential,
activate2,
};
Tag tag;
Operator(Tag tag) : tag(tag) {}
Expand Down Expand Up @@ -1531,6 +1550,11 @@ struct OperatorExponentialWrapper final : Operator {
OperatorExponentialWrapper() : Operator(Tag::exponential) {}
};

struct OperatorActivate2Wrapper final : Operator {
Ptr<Activate2> op;
OperatorActivate2Wrapper() : Operator(Tag::activate2) {}
};

struct Stmt {
enum class Tag {
oper = 1,
Expand Down
81 changes: 79 additions & 2 deletions interop/klr/klir_pretty_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3139,11 +3139,74 @@ std::string to_string(Exponential &ExponentialInstance) {
result += "maxValue=";
result += to_string(*(ExponentialInstance.maxValue.get()));
result += ", ";
result += "reduceRes=";
if (ExponentialInstance.reduceRes.has_value()) {
result += to_string(*(ExponentialInstance.reduceRes.value().get()));
} else {
result += "None";
}
result += ", ";
result += "reducecmd=";
result += to_string(ExponentialInstance.reducecmd); // mapped from enum
result += ", ";
result += "ReduceInit=";
result += to_string(*(ExponentialInstance.ReduceInit.get()));
result += "reduceInit=";
result += to_string(*(ExponentialInstance.reduceInit.get()));
result += ")";
return result;
};

std::string to_string(Activate2 &Activate2Instance) {
std::string result;
result += "Activate2(";
result += "dst=";
result += to_string(*(Activate2Instance.dst.get()));
result += ", ";
result += "src=";
result += to_string(*(Activate2Instance.src.get()));
result += ", ";
result += "op0=";
result += to_string(Activate2Instance.op0); // mapped from enum
result += ", ";
result += "op1=";
result += to_string(Activate2Instance.op1); // mapped from enum
result += ", ";
result += "imm0=";
result += to_string(*(Activate2Instance.imm0.get()));
result += ", ";
result += "imm1=";
result += to_string(*(Activate2Instance.imm1.get()));
result += ", ";
result += "activationFunc=";
result += to_string(Activate2Instance.activationFunc); // mapped from enum
result += ", ";
result += "reluParam=";
result += to_string(*(Activate2Instance.reluParam.get()));
result += ", ";
result += "reduceOp=";
result += to_string(Activate2Instance.reduceOp); // mapped from enum
result += ", ";
result += "reduceRes=";
if (Activate2Instance.reduceRes.has_value()) {
result += to_string(*(Activate2Instance.reduceRes.value().get()));
} else {
result += "None";
}
result += ", ";
result += "reduceCmd=";
result += to_string(Activate2Instance.reduceCmd); // mapped from enum
result += ", ";
result += "reverse0=";
result += std::to_string(Activate2Instance.reverse0);
result += ", ";
result += "reverse1=";
result += std::to_string(Activate2Instance.reverse1);
result += ", ";
result += "dtype=";
if (Activate2Instance.dtype.has_value()) {
result += to_string(Activate2Instance.dtype.value()); // mapped from enum
} else {
result += "None";
}
result += ")";
return result;
};
Expand Down Expand Up @@ -3814,6 +3877,15 @@ to_string(OperatorExponentialWrapper &OperatorExponentialWrapperInstance) {
result += ")";
return result;
};
std::string
to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance) {
std::string result;
result += "OperatorActivate2Wrapper(";
result += "op=";
result += to_string(*(OperatorActivate2WrapperInstance.op.get()));
result += ")";
return result;
};
std::string to_string(Operator &OperatorInstance) {
switch (OperatorInstance.tag) {
case (Operator::Tag::activate): {
Expand Down Expand Up @@ -4193,6 +4265,11 @@ std::string to_string(Operator &OperatorInstance) {
static_cast<OperatorExponentialWrapper &>(OperatorInstance);
return to_string(derivedRef);
}
case (Operator::Tag::activate2): {
OperatorActivate2Wrapper &derivedRef =
static_cast<OperatorActivate2Wrapper &>(OperatorInstance);
return to_string(derivedRef);
}
default:
return "UNABLE TO PRINT";
}
Expand Down
4 changes: 4 additions & 0 deletions interop/klr/klir_pretty_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ std::string to_string(DevicePrint &DevicePrintInstance);

std::string to_string(Exponential &ExponentialInstance);

std::string to_string(Activate2 &Activate2Instance);

std::string to_string(OperatorActivateWrapper &OperatorActivateWrapperInstance);
std::string
to_string(OperatorNcActivateWrapper &OperatorNcActivateWrapperInstance);
Expand Down Expand Up @@ -399,6 +401,8 @@ std::string
to_string(OperatorDevicePrintWrapper &OperatorDevicePrintWrapperInstance);
std::string
to_string(OperatorExponentialWrapper &OperatorExponentialWrapperInstance);
std::string
to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance);
std::string to_string(Operator &OperatorInstance);

std::string to_string(StmtOperWrapper &StmtOperWrapperInstance);
Expand Down
Loading
Loading