diff --git a/pkg/sql/plan/base_binder.go b/pkg/sql/plan/base_binder.go index e08d17b6e151c..366650a960bac 100644 --- a/pkg/sql/plan/base_binder.go +++ b/pkg/sql/plan/base_binder.go @@ -835,6 +835,11 @@ func (b *baseBinder) bindComparisonExpr(astExpr *tree.ComparisonExpr, depth int3 return b.bindFuncExprImplByAstExpr("not", []tree.Expr{newExpr}, depth) case tree.IN: + if leftTuple, ok := astExpr.Left.(*tree.Tuple); ok { + if rightTuple, ok := astExpr.Right.(*tree.Tuple); ok { + return b.bindTupleInByAst(leftTuple, rightTuple, depth, false) + } + } switch r := astExpr.Right.(type) { case *tree.Tuple: op = "in" @@ -879,6 +884,11 @@ func (b *baseBinder) bindComparisonExpr(astExpr *tree.ComparisonExpr, depth int3 } case tree.NOT_IN: + if leftTuple, ok := astExpr.Left.(*tree.Tuple); ok { + if rightTuple, ok := astExpr.Right.(*tree.Tuple); ok { + return b.bindTupleInByAst(leftTuple, rightTuple, depth, true) + } + } switch astExpr.Right.(type) { case *tree.Tuple: op = "not_in" @@ -977,6 +987,51 @@ func (b *baseBinder) bindComparisonExpr(astExpr *tree.ComparisonExpr, depth int3 return b.bindFuncExprImplByAstExpr(op, []tree.Expr{astExpr.Left, astExpr.Right}, depth) } +func (b *baseBinder) bindTupleInByAst(leftTuple *tree.Tuple, rightTuple *tree.Tuple, depth int32, isNot bool) (*plan.Expr, error) { + var newExpr *plan.Expr + + for _, rightVal := range rightTuple.Exprs { + rightTupleVal, ok := rightVal.(*tree.Tuple) + if !ok { + return nil, moerr.NewInternalError(b.GetContext(), "IN list must contain tuples") + } + if len(leftTuple.Exprs) != len(rightTupleVal.Exprs) { + return nil, moerr.NewInternalError(b.GetContext(), "tuple length mismatch") + } + + var andExpr *plan.Expr + for i := 0; i < len(leftTuple.Exprs); i++ { + eqExpr, err := b.bindFuncExprImplByAstExpr("=", []tree.Expr{leftTuple.Exprs[i], rightTupleVal.Exprs[i]}, depth) + if err != nil { + return nil, err + } + if andExpr == nil { + andExpr = eqExpr + } else { + andExpr, err = BindFuncExprImplByPlanExpr(b.GetContext(), "and", []*plan.Expr{andExpr, eqExpr}) + if err != nil { + return nil, err + } + } + } + + if newExpr == nil { + newExpr = andExpr + } else { + var err error + newExpr, err = BindFuncExprImplByPlanExpr(b.GetContext(), "or", []*plan.Expr{newExpr, andExpr}) + if err != nil { + return nil, err + } + } + } + + if isNot { + return BindFuncExprImplByPlanExpr(b.GetContext(), "not", []*plan.Expr{newExpr}) + } + return newExpr, nil +} + func (b *baseBinder) bindFuncExpr(astExpr *tree.FuncExpr, depth int32, isRoot bool) (*Expr, error) { funcRef, ok := astExpr.Func.FunctionReference.(*tree.UnresolvedName) if !ok { diff --git a/pkg/sql/plan/query_builder_test.go b/pkg/sql/plan/query_builder_test.go index d2596670b7e78..d04bb8cf540ee 100644 --- a/pkg/sql/plan/query_builder_test.go +++ b/pkg/sql/plan/query_builder_test.go @@ -1551,6 +1551,39 @@ func TestBaseBinder_bindComparisonExpr(t *testing.T) { require.Equal(t, "not_in", funcExpr.F.Func.ObjName) }, }, + { + name: "Tuple IN: (a, b) IN ((1, 2), (3, 4))", + sql: "(a, b) IN ((1, 2), (3, 4))", + expectErr: false, + checkFunc: func(t *testing.T, expr *plan.Expr, err error) { + require.NoError(t, err) + require.NotNil(t, expr) + funcExpr, ok := expr.Expr.(*plan.Expr_F) + require.True(t, ok) + require.Equal(t, "or", funcExpr.F.Func.ObjName) + }, + }, + { + name: "Tuple NOT IN: (a, b) NOT IN ((1, 2), (3, 4))", + sql: "(a, b) NOT IN ((1, 2), (3, 4))", + expectErr: false, + checkFunc: func(t *testing.T, expr *plan.Expr, err error) { + require.NoError(t, err) + require.NotNil(t, expr) + funcExpr, ok := expr.Expr.(*plan.Expr_F) + require.True(t, ok) + require.Equal(t, "not", funcExpr.F.Func.ObjName) + }, + }, + { + name: "Tuple IN length mismatch: (a, b) IN ((1, 2, 3))", + sql: "(a, b) IN ((1, 2, 3))", + expectErr: true, + checkFunc: func(t *testing.T, expr *plan.Expr, err error) { + require.Error(t, err) + require.Contains(t, err.Error(), "tuple length mismatch") + }, + }, // Tuple comparisons { name: "Tuple EQUAL: (a, b) = (1, 2)", diff --git a/test/distributed/cases/operator/row_constructor.result b/test/distributed/cases/operator/row_constructor.result index 93c61fa928121..438e7245743ac 100644 --- a/test/distributed/cases/operator/row_constructor.result +++ b/test/distributed/cases/operator/row_constructor.result @@ -57,6 +57,36 @@ select (-2,1,3) >= (-1,2,3); select (-387293.324321,32190391.34134,000) <= (-387293.324321, -123, -1); (-387293.324321, 32190391.34134, 0) <= (-387293.324321, -123, -1) 0 +select (1,2) in ((1,2),(3,4)); +(1, 2) in ((1, 2), (3, 4)) +1 +select (1,2) in ((1,3),(3,4)); +(1, 2) in ((1, 3), (3, 4)) +0 +select (1,2) in ((1,2)); +(1, 2) in ((1, 2)) +1 +select (1,2) not in ((1,2)); +(1, 2) not in ((1, 2)) +0 +select (1,2) not in ((1,3),(3,4)); +(1, 2) not in ((1, 3), (3, 4)) +1 +select (1,2) in ((1,null)); +(1, 2) in ((1, null)) +null +select (1,2) in ((1,null),(1,2)); +(1, 2) in ((1, null), (1, 2)) +1 +select (1,2) not in ((1,null)); +(1, 2) not in ((1, null)) +null +select (1,null) in ((1,null)); +(1, null) in ((1, null)) +null +select (1,2) not in ((1,null),(2,2)); +(1, 2) not in ((1, null), (2, 2)) +null select (1,2,3) > (-1,-3+2,2*3); (1, 2, 3) > (-1, -3 + 2, 2 * 3) 1 diff --git a/test/distributed/cases/operator/row_constructor.sql b/test/distributed/cases/operator/row_constructor.sql index 67c925d3e91e8..b94187632b843 100644 --- a/test/distributed/cases/operator/row_constructor.sql +++ b/test/distributed/cases/operator/row_constructor.sql @@ -25,6 +25,16 @@ select (1,null) < (2,null); select (2,3) >= (1,3); select (-2,1,3) >= (-1,2,3); select (-387293.324321,32190391.34134,000) <= (-387293.324321, -123, -1); +select (1,2) in ((1,2),(3,4)); +select (1,2) in ((1,3),(3,4)); +select (1,2) in ((1,2)); +select (1,2) not in ((1,2)); +select (1,2) not in ((1,3),(3,4)); +select (1,2) in ((1,null)); +select (1,2) in ((1,null),(1,2)); +select (1,2) not in ((1,null)); +select (1,null) in ((1,null)); +select (1,2) not in ((1,null),(2,2)); -- + - * / % mod select (1,2,3) > (-1,-3+2,2*3);