diff --git a/KLR/Core/Basic.lean b/KLR/Core/Basic.lean
index 54671af5..3c67f158 100644
--- a/KLR/Core/Basic.lean
+++ b/KLR/Core/Basic.lean
@@ -188,6 +188,7 @@ partial def operatorBasicTensors : Operator → List TensorRef
   | .nonzeroWithCount n => [n.dst, n.src]
   | .devicePrint t => [t.src]
   | .exponential e => [e.dst, e.src]
+  | .activate2 a => [a.dst, a.src]
 
 partial def operatorAdditionalTensors : Operator → List TensorName
   | .ncActivate d => (tensors d.scale) ++ (tensors d.bias) ++ (tensors d.reduceRes)
@@ -213,7 +214,7 @@ partial def operatorAdditionalTensors : Operator → List TensorName
   | .tensorScalarCumulative t => (tensors t.imm0) ++ (tensors t.imm1)
   | .ncNGather _ => []
   | .nonzeroWithCount _ => []
-  | .exponential e => (tensors e.maxValue) ++ (tensors e.ReduceInit)
+  | .exponential e => (tensors e.maxValue) ++ (tensors e.reduceRes) ++ (tensors e.reduceInit)
   | _ => []
 
 instance : Tensors Operator where
diff --git a/KLR/Core/Operators.lean b/KLR/Core/Operators.lean
index 3d104a01..3c9ebcdc 100644
--- a/KLR/Core/Operators.lean
+++ b/KLR/Core/Operators.lean
@@ -1219,8 +1219,9 @@ structure Exponential where
   dst : TensorRef
   src : TensorRef
   maxValue : Operand
+  reduceRes : Option TensorRef
   reducecmd : AccumCmd
-  ReduceInit : Operand
+  reduceInit : Operand
   deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
 
 instance : MapTensorRefs Exponential where
@@ -1228,10 +1229,39 @@ instance : MapTensorRefs Exponential where
   mapM ft fo op := do pure { op with
     dst := ← ft op.dst,
     src := ← ft op.src,
     maxValue := ← fo op.maxValue,
-    ReduceInit := ← fo op.ReduceInit
+    reduceRes := ← op.reduceRes.mapM ft,
+    reduceInit := ← fo op.reduceInit
   }
 
 @[serde tag = 217]
+structure Activate2 where
+  dst : TensorRef
+  src : TensorRef
+  op0 : AluOp
+  op1 : AluOp
+  imm0 : Operand
+  imm1 : Operand
+  activationFunc : ActivationFunc
+  reluParam : Operand
+  reduceOp : AluOp
+  reduceRes : Option TensorRef
+  reduceCmd : AccumCmd
+  reverse0 : Bool
+  reverse1 : Bool
+  dtype : Option Dtype
+  deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
+
+instance : MapTensorRefs Activate2 where
+  mapM ft fo op := do pure { op with
+    dst := ← ft op.dst,
+    src := ← ft op.src,
+    imm0 := ← fo op.imm0,
+    imm1 := ← fo op.imm1,
+    reluParam := ← fo op.reluParam,
+    reduceRes := ← op.reduceRes.mapM ft
+  }
+
+@[serde tag = 218]
 inductive Operator where
   | activate (op : Activate)
   | ncActivate (op : NcActivate)
@@ -1308,9 +1338,10 @@ inductive Operator where
   | nonzeroWithCount (op: NonzeroWithCount)
   | devicePrint (op: DevicePrint)
   | exponential(op: Exponential)
+  | activate2 (op: Activate2)
   deriving BEq, FromCBOR, FromJson, FromSexp, Repr, ToCBOR, ToJson, ToSexp
 
-@[serde tag = 218]
+@[serde tag = 219]
 inductive TGROperator where
   | activate (op : Activate)
   | affineSelect (op : AffineSelect)
@@ -1419,3 +1450,4 @@ instance : MapTensorRefs Operator where
   | .nonzeroWithCount op => return .nonzeroWithCount (← MapTensorRefs.mapM ft fo op)
   | .devicePrint op => return .devicePrint (← MapTensorRefs.mapM ft fo op)
   | .exponential op => return .exponential (← MapTensorRefs.mapM ft fo op)
+  | .activate2 op => return .activate2 (← MapTensorRefs.mapM ft fo op)
diff --git a/KLR/Extract/Extract/Basic.lean b/KLR/Extract/Extract/Basic.lean
index 6d8b9a8d..b0c6bc74 100644
--- a/KLR/Extract/Extract/Basic.lean
+++ b/KLR/Extract/Extract/Basic.lean
@@ -375,6 +375,7 @@ def klrAST: MetaM (List LeanType) := do
     `KLR.Core.PrintOutputBuffer,
     `KLR.Core.DevicePrint,
     `KLR.Core.Exponential,
+    `KLR.Core.Activate2,
     `KLR.Core.Operator,
     -- Core.Basic
     `KLR.Core.Stmt,
diff --git a/KLR/Trace/ISA.lean b/KLR/Trace/ISA.lean
index a9faa062..12b44496 100644
--- a/KLR/Trace/ISA.lean
+++ b/KLR/Trace/ISA.lean
@@ -1235,6 +1235,7 @@ nki builtin.isa.exponential
   (dst : Access)
   (src : Access)
   (max_value : Sum Immediate Access := .inl $ .float 0.0)
+  (reduce_res : Option Access := none)
   (reduce_cmd : AccumCmd := .Idle)
   (reduce_init : Sum Immediate Access := .inl $ .float 0.0)
   (mask : Option Immediate := none)
@@ -1246,9 +1247,52 @@ nki builtin.isa.exponential
     maxValue := match max_value with
     | .inl imm => .imm imm
     | .inr t => .tile $ .abstract t,
+    reduceRes := reduce_res.map .abstract,
     reducecmd := reduce_cmd,
-    ReduceInit := match reduce_init with
+    reduceInit := match reduce_init with
     | .inl imm => .imm imm
     | .inr t => .tile $ .abstract t
   }) name
   return .none
+
+nki builtin.isa.activate2
+  (dst : Access)
+  (src : Access)
+  (op0 : AluOp)
+  (op1 : AluOp)
+  (imm0 : Sum Immediate Access)
+  (imm1 : Sum Immediate Access)
+  (op : ActivationFunc)
+  -- kwargs
+  (relu_param : Sum Immediate Access := .inl $ .float 0.0)
+  (reduce_op : AluOp := .bypass)
+  (reduce_res : Option Access := none)
+  (reduce_cmd : AccumCmd := .Idle)
+  (reverse0 : Bool := false)
+  (reverse1 : Bool := false)
+  (mask : Option Immediate := none)
+  (name : Option String := none) := do
+  if mask.isSome then throw maskNotSupported
+  Trace.add_stmt $ .oper (.activate2 {
+    dst := .abstract dst,
+    src := .abstract src,
+    op0 := op0,
+    op1 := op1,
+    imm0 := match imm0 with
+    | .inl imm => .imm imm
+    | .inr t => .tile $ .abstract t,
+    imm1 := match imm1 with
+    | .inl imm => .imm imm
+    | .inr t => .tile $ .abstract t,
+    activationFunc := op,
+    reluParam := match relu_param with
+    | .inl imm => .imm imm
+    | .inr t => .tile $ .abstract t,
+    reduceOp := reduce_op,
+    reduceRes := reduce_res.map .abstract,
+    reduceCmd := reduce_cmd,
+    reverse0 := reverse0,
+    reverse1 := reverse1,
+    dtype := dst.tensor.dtype
+  }) name
+  return .none
diff --git a/interop/klr/NKI.asdl b/interop/klr/NKI.asdl
index 6d96f57c..ef9da411 100644
--- a/interop/klr/NKI.asdl
+++ b/interop/klr/NKI.asdl
@@ -363,7 +363,9 @@ PrintOutputBuffer =
 DevicePrint = (TensorRef src, String printPrefix, PrintOutputBuffer buffer)
 
-Exponential = (TensorRef dst, TensorRef src, Operand maxValue, AccumCmd reducecmd, Operand ReduceInit)
+Exponential = (TensorRef dst, TensorRef src, Operand maxValue, TensorRef? reduceRes, AccumCmd reducecmd, Operand reduceInit)
+
+Activate2 = (TensorRef dst, TensorRef src, AluOp op0, AluOp op1, Operand imm0, Operand imm1, ActivationFunc activationFunc, Operand reluParam, AluOp reduceOp, TensorRef? reduceRes, AccumCmd reduceCmd, Bool reverse0, Bool reverse1, Dtype? dtype)
 
 Operator =
   | activate(Activate op)
   | ncActivate(NcActivate op)
@@ -440,6 +442,7 @@ Operator =
   | nonzeroWithCount(NonzeroWithCount op)
   | devicePrint(DevicePrint op)
   | exponential(Exponential op)
+  | activate2(Activate2 op)
 
 Stmt =
   | oper(Operator op, String? name, Pos pos)
diff --git a/interop/klr/klir_ast.hpp b/interop/klr/klir_ast.hpp
index f25e487c..baf4be72 100644
--- a/interop/klr/klir_ast.hpp
+++ b/interop/klr/klir_ast.hpp
@@ -1065,8 +1065,26 @@ struct Exponential final {
   Ptr<TensorRef> dst;
   Ptr<TensorRef> src;
   Ptr<Operand> maxValue;
+  Option<Ptr<TensorRef>> reduceRes;
   AccumCmd reducecmd;
-  Ptr<Operand> ReduceInit;
+  Ptr<Operand> reduceInit;
+};
+
+struct Activate2 final {
+  Ptr<TensorRef> dst;
+  Ptr<TensorRef> src;
+  AluOp op0;
+  AluOp op1;
+  Ptr<Operand> imm0;
+  Ptr<Operand> imm1;
+  ActivationFunc activationFunc;
+  Ptr<Operand> reluParam;
+  AluOp reduceOp;
+  Option<Ptr<TensorRef>> reduceRes;
+  AccumCmd reduceCmd;
+  Bool reverse0;
+  Bool reverse1;
+  Option<Dtype> dtype;
 };
 
 struct Operator {
@@ -1146,6 +1164,7 @@ struct Operator {
     nonzeroWithCount,
     devicePrint,
     exponential,
+    activate2,
   };
   Tag tag;
   Operator(Tag tag) : tag(tag) {}
@@ -1531,6 +1550,11 @@ struct OperatorExponentialWrapper final : Operator {
   Ptr<Exponential> op;
   OperatorExponentialWrapper() : Operator(Tag::exponential) {}
 };
 
+struct OperatorActivate2Wrapper final : Operator {
+  Ptr<Activate2> op;
+  OperatorActivate2Wrapper() : Operator(Tag::activate2) {}
+};
+
 struct Stmt {
   enum class Tag {
     oper = 1,
diff --git a/interop/klr/klir_pretty_print.cpp b/interop/klr/klir_pretty_print.cpp
index db08694b..42ce6d40 100644
--- a/interop/klr/klir_pretty_print.cpp
+++ b/interop/klr/klir_pretty_print.cpp
@@ -3139,11 +3139,74 @@ std::string to_string(Exponential &ExponentialInstance) {
   result += "maxValue=";
   result += to_string(*(ExponentialInstance.maxValue.get()));
   result += ", ";
+  result += "reduceRes=";
+  if (ExponentialInstance.reduceRes.has_value()) {
+    result += to_string(*(ExponentialInstance.reduceRes.value().get()));
+  } else {
+    result += "None";
+  }
+  result += ", ";
   result += "reducecmd=";
   result += to_string(ExponentialInstance.reducecmd); // mapped from enum
   result += ", ";
-  result += "ReduceInit=";
-  result += to_string(*(ExponentialInstance.ReduceInit.get()));
+  result += "reduceInit=";
+  result += to_string(*(ExponentialInstance.reduceInit.get()));
+  result += ")";
+  return result;
+};
+
+std::string to_string(Activate2 &Activate2Instance) {
+  std::string result;
+  result += "Activate2(";
+  result += "dst=";
+  result += to_string(*(Activate2Instance.dst.get()));
+  result += ", ";
+  result += "src=";
+  result += to_string(*(Activate2Instance.src.get()));
+  result += ", ";
+  result += "op0=";
+  result += to_string(Activate2Instance.op0); // mapped from enum
+  result += ", ";
+  result += "op1=";
+  result += to_string(Activate2Instance.op1); // mapped from enum
+  result += ", ";
+  result += "imm0=";
+  result += to_string(*(Activate2Instance.imm0.get()));
+  result += ", ";
+  result += "imm1=";
+  result += to_string(*(Activate2Instance.imm1.get()));
+  result += ", ";
+  result += "activationFunc=";
+  result += to_string(Activate2Instance.activationFunc); // mapped from enum
+  result += ", ";
+  result += "reluParam=";
+  result += to_string(*(Activate2Instance.reluParam.get()));
+  result += ", ";
+  result += "reduceOp=";
+  result += to_string(Activate2Instance.reduceOp); // mapped from enum
+  result += ", ";
+  result += "reduceRes=";
+  if (Activate2Instance.reduceRes.has_value()) {
+    result += to_string(*(Activate2Instance.reduceRes.value().get()));
+  } else {
+    result += "None";
+  }
+  result += ", ";
+  result += "reduceCmd=";
+  result += to_string(Activate2Instance.reduceCmd); // mapped from enum
+  result += ", ";
+  result += "reverse0=";
+  result += std::to_string(Activate2Instance.reverse0);
+  result += ", ";
+  result += "reverse1=";
+  result += std::to_string(Activate2Instance.reverse1);
+  result += ", ";
+  result += "dtype=";
+  if (Activate2Instance.dtype.has_value()) {
+    result += to_string(Activate2Instance.dtype.value()); // mapped from enum
+  } else {
+    result += "None";
+  }
   result += ")";
   return result;
 };
@@ -3814,6 +3877,15 @@ to_string(OperatorExponentialWrapper &OperatorExponentialWrapperInstance) {
   result += ")";
   return result;
 };
+std::string
+to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance) {
+  std::string result;
+  result += "OperatorActivate2Wrapper(";
+  result += "op=";
+  result += to_string(*(OperatorActivate2WrapperInstance.op.get()));
+  result += ")";
+  return result;
+};
 std::string to_string(Operator &OperatorInstance) {
   switch (OperatorInstance.tag) {
   case (Operator::Tag::activate): {
@@ -4193,6 +4265,11 @@ std::string to_string(Operator &OperatorInstance) {
   case (Operator::Tag::exponential): {
     OperatorExponentialWrapper &derivedRef =
        static_cast<OperatorExponentialWrapper &>(OperatorInstance);
     return to_string(derivedRef);
   }
+  case (Operator::Tag::activate2): {
+    OperatorActivate2Wrapper &derivedRef =
+       static_cast<OperatorActivate2Wrapper &>(OperatorInstance);
+    return to_string(derivedRef);
+  }
   default:
     return "UNABLE TO PRINT";
   }
diff --git a/interop/klr/klir_pretty_print.hpp b/interop/klr/klir_pretty_print.hpp
index e49da8d5..856ffed0 100644
--- a/interop/klr/klir_pretty_print.hpp
+++ b/interop/klr/klir_pretty_print.hpp
@@ -266,6 +266,8 @@ std::string to_string(DevicePrint &DevicePrintInstance);
 
 std::string to_string(Exponential &ExponentialInstance);
 
+std::string to_string(Activate2 &Activate2Instance);
+
 std::string to_string(OperatorActivateWrapper &OperatorActivateWrapperInstance);
 std::string
 to_string(OperatorNcActivateWrapper &OperatorNcActivateWrapperInstance);
@@ -399,6 +401,8 @@ std::string
 to_string(OperatorDevicePrintWrapper &OperatorDevicePrintWrapperInstance);
 std::string
 to_string(OperatorExponentialWrapper &OperatorExponentialWrapperInstance);
+std::string
+to_string(OperatorActivate2Wrapper &OperatorActivate2WrapperInstance);
 std::string to_string(Operator &OperatorInstance);
 
 std::string to_string(StmtOperWrapper &StmtOperWrapperInstance);
diff --git a/interop/klr/klir_serde.cpp b/interop/klr/klir_serde.cpp
index c1489e0c..52eb610c 100644
--- a/interop/klr/klir_serde.cpp
+++ b/interop/klr/klir_serde.cpp
@@ -2880,7 +2880,7 @@ bool DevicePrint_ser(FILE *out, const Ptr<DevicePrint> &value) {
 }
 
 bool Exponential_ser(FILE *out, const Ptr<Exponential> &value) {
-  if (!serialize_tag(out, 216, 0, 5))
+  if (!serialize_tag(out, 216, 0, 6))
     return false;
   if (!TensorRef_ser(out, value->dst))
     return false;
@@ -2888,9 +2888,45 @@ bool Exponential_ser(FILE *out, const Ptr<Exponential> &value) {
     return false;
   if (!Operand_ser(out, value->maxValue))
     return false;
+  if (!Option_TensorRef_ser(out, value->reduceRes))
+    return false;
   if (!AccumCmd_ser(out, value->reducecmd))
     return false;
-  if (!Operand_ser(out, value->ReduceInit))
+  if (!Operand_ser(out, value->reduceInit))
+    return false;
+  return true;
+}
+
+bool Activate2_ser(FILE *out, const Ptr<Activate2> &value) {
+  if (!serialize_tag(out, 217, 0, 14))
+    return false;
+  if (!TensorRef_ser(out, value->dst))
+    return false;
+  if (!TensorRef_ser(out, value->src))
+    return false;
+  if (!AluOp_ser(out, value->op0))
+    return false;
+  if (!AluOp_ser(out, value->op1))
+    return false;
+  if (!Operand_ser(out, value->imm0))
+    return false;
+  if (!Operand_ser(out, value->imm1))
+    return false;
+  if (!ActivationFunc_ser(out, value->activationFunc))
+    return false;
+  if (!Operand_ser(out, value->reluParam))
+    return false;
+  if (!AluOp_ser(out, value->reduceOp))
+    return false;
+  if (!Option_TensorRef_ser(out, value->reduceRes))
+    return false;
+  if (!AccumCmd_ser(out, value->reduceCmd))
+    return false;
+  if (!Bool_ser(out, value->reverse0))
+    return false;
+  if (!Bool_ser(out, value->reverse1))
+    return false;
+  if (!Option_Dtype_ser(out, value->dtype))
     return false;
   return true;
 }
@@ -3201,13 +3237,17 @@ bool Operator_ser(FILE *out, const Ptr<Operator> &value) {
     tag_val = 74;
     field_count = 1;
     break;
+  case Operator::Tag::activate2:
+    tag_val = 75;
+    field_count = 1;
+    break;
   default:
     throw std::runtime_error("Unknown Operator type in serialization");
     return false;
   }
 
   // Serialize the tag
-  if (!serialize_tag(out, 217, tag_val, field_count))
+  if (!serialize_tag(out, 218, tag_val, field_count))
     return false;
 
   // Serialize the fields based on the specific variant
@@ -3579,6 +3619,11 @@ bool Operator_ser(FILE *out, const Ptr<Operator> &value) {
        static_cast<OperatorExponentialWrapper *>(value.get());
     return Exponential_ser(out, typed_value->op);
   }
+  case Operator::Tag::activate2: {
+    auto *typed_value =
+       static_cast<OperatorActivate2Wrapper *>(value.get());
+    return Activate2_ser(out, typed_value->op);
+  }
   default:
     throw std::runtime_error("Unknown Operator type in serialization");
     return false;
   }
@@ -6974,9 +7019,9 @@ Ptr<Exponential> Exponential_des(FILE *in) {
     msg << "Could not find tag, expecting Exponential:216,0";
     throw std::runtime_error(msg.str());
   }
-  if (t != 216 || c != 0 || l != 5) {
+  if (t != 216 || c != 0 || l != 6) {
    std::ostringstream msg;
-    msg << "Expecting Exponential:(216,0,5)";
+    msg << "Expecting Exponential:(216,0,6)";
    msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
    throw std::runtime_error(msg.str());
   }
@@ -6984,8 +7029,40 @@ Ptr<Exponential> Exponential_des(FILE *in) {
   x->dst = TensorRef_des(in);
   x->src = TensorRef_des(in);
   x->maxValue = Operand_des(in);
+  x->reduceRes = Option_TensorRef_des(in);
   x->reducecmd = AccumCmd_des(in);
-  x->ReduceInit = Operand_des(in);
+  x->reduceInit = Operand_des(in);
+  return x;
+}
+
+Ptr<Activate2> Activate2_des(FILE *in) {
+  u8 t, c, l;
+  if (!deserialize_tag(in, &t, &c, &l)) {
+    std::ostringstream msg;
+    msg << "Could not find tag, expecting Activate2:217,0";
+    throw std::runtime_error(msg.str());
+  }
+  if (t != 217 || c != 0 || l != 14) {
+    std::ostringstream msg;
+    msg << "Expecting Activate2:(217,0,14)";
+    msg << " got:(" << (int)t << "," << (int)c << "," << (int)l << ")";
+    throw std::runtime_error(msg.str());
+  }
+  Ptr<Activate2> x = ptr<Activate2>();
+  x->dst = TensorRef_des(in);
+  x->src = TensorRef_des(in);
+  x->op0 = AluOp_des(in);
+  x->op1 = AluOp_des(in);
+  x->imm0 = Operand_des(in);
+  x->imm1 = Operand_des(in);
+  x->activationFunc = ActivationFunc_des(in);
+  x->reluParam = Operand_des(in);
+  x->reduceOp = AluOp_des(in);
+  x->reduceRes = Option_TensorRef_des(in);
+  x->reduceCmd = AccumCmd_des(in);
+  x->reverse0 = Bool_des(in);
+  x->reverse1 = Bool_des(in);
+  x->dtype = Option_Dtype_des(in);
   return x;
 }
@@ -6993,7 +7070,7 @@ Ptr<Operator> Operator_des(FILE *in) {
   u8 t, c, l;
   if (!deserialize_tag(in, &t, &c, &l))
     throw std::runtime_error("Could not read tag");
-  if (t != 217)
+  if (t != 218)
     throw std::runtime_error("Unexpected type tag");
   switch (c) {
   case 0: {
@@ -7611,6 +7688,14 @@ Ptr<Operator> Operator_des(FILE *in) {
     return x;
     break;
   }
+  case 75: {
+    if (l != 1)
+      throw std::runtime_error("Wrong number of elements");
+    Ptr<OperatorActivate2Wrapper> x = ptr<OperatorActivate2Wrapper>();
+    x->op = Activate2_des(in);
+    return x;
+    break;
+  }
   default:
     throw std::runtime_error("Invalid value tag");
   }
diff --git a/interop/klr/klir_serde.hpp b/interop/klr/klir_serde.hpp
index ae3eaec9..cbfc96f6 100644
--- a/interop/klr/klir_serde.hpp
+++ b/interop/klr/klir_serde.hpp
@@ -177,6 +177,7 @@ bool NonzeroWithCount_ser(FILE *out, const Ptr<NonzeroWithCount> &value);
 bool PrintOutputBuffer_ser(FILE *out, const PrintOutputBuffer &value);
 bool DevicePrint_ser(FILE *out, const Ptr<DevicePrint> &value);
 bool Exponential_ser(FILE *out, const Ptr<Exponential> &value);
+bool Activate2_ser(FILE *out, const Ptr<Activate2> &value);
 bool Operator_ser(FILE *out, const Ptr<Operator> &value);
 bool Stmt_ser(FILE *out, const Ptr<Stmt> &value);
 bool Block_ser(FILE *out, const Ptr<Block> &value);
@@ -308,6 +309,7 @@ Ptr<NonzeroWithCount> NonzeroWithCount_des(FILE *in);
 PrintOutputBuffer PrintOutputBuffer_des(FILE *in);
 Ptr<DevicePrint> DevicePrint_des(FILE *in);
 Ptr<Exponential> Exponential_des(FILE *in);
+Ptr<Activate2> Activate2_des(FILE *in);
 Ptr<Operator> Operator_des(FILE *in);
 Ptr<Stmt> Stmt_des(FILE *in);
 Ptr<Block> Block_des(FILE *in);
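Illustrative sketch (not part of the change): the snippet below shows how the Activate2 serde entry points introduced above, Activate2_ser, Activate2_des, and the to_string(Activate2 &) overload, could be exercised in a small round-trip check on the C++ side. It assumes Ptr is the project's existing smart-pointer alias with a get()/dereference interface and that the node has already been populated elsewhere (building the TensorRef and Operand operands is omitted); the function name, file handling, and include set are assumptions for the example only.

// Minimal round-trip sketch, under the assumptions stated above.
#include <cstdio>
#include "klir_ast.hpp"
#include "klir_pretty_print.hpp"
#include "klir_serde.hpp"

// Serializes `node` to `path`, deserializes it again, and pretty-prints it.
bool activate2_roundtrip(const Ptr<Activate2> &node, const char *path) {
  FILE *out = std::fopen(path, "wb");
  if (!out || !Activate2_ser(out, node)) // writes tag (217, 0, 14) plus the 14 fields
    return false;
  std::fclose(out);

  FILE *in = std::fopen(path, "rb");
  Ptr<Activate2> back = Activate2_des(in); // throws on an unexpected tag
  std::fclose(in);

  std::printf("%s\n", to_string(*back).c_str()); // "Activate2(dst=..., ...)"
  return true;
}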