diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5fb0eee7..39d8bfde 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -33,6 +33,8 @@ jobs: with: args: -color + - name: Update Go dependencies + run: go mod tidy - name: Run golangci-lint uses: golangci/golangci-lint-action@v8.0.0 diff --git a/go.mod b/go.mod index 917dbd4e..81a0a586 100644 --- a/go.mod +++ b/go.mod @@ -142,3 +142,5 @@ require ( replace github.com/chzyer/readline => github.com/stackql/readline v0.0.2-alpha05 replace github.com/mattn/go-sqlite3 => github.com/stackql/stackql-go-sqlite3 v1.0.4-stackql + +replace github.com/stackql/stackql-parser => github.com/stackql/stackql-parser v0.0.0-20251202115006-1595204710ca diff --git a/internal/stackql/astanalysis/earlyanalysis/ast_expand.go b/internal/stackql/astanalysis/earlyanalysis/ast_expand.go index 4e03cccc..6f16c2b5 100644 --- a/internal/stackql/astanalysis/earlyanalysis/ast_expand.go +++ b/internal/stackql/astanalysis/earlyanalysis/ast_expand.go @@ -49,6 +49,7 @@ type indirectExpandAstVisitor struct { selectCount int mutateCount int createBuilder []primitivebuilder.Builder + cteRegistry map[string]*sqlparser.CommonTableExpr } func newIndirectExpandAstVisitor( @@ -75,6 +76,7 @@ func newIndirectExpandAstVisitor( tcc: tcc, whereParams: whereParams, indirectionDepth: indirectionDepth, + cteRegistry: make(map[string]*sqlparser.CommonTableExpr), } return rv, nil } @@ -214,6 +216,21 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error { addIf(node.StraightJoinHint, sqlparser.StraightJoinHint) addIf(node.SQLCalcFoundRows, sqlparser.SQLCalcFoundRowsStr) + // Process CTEs (Common Table Expressions) if present + if node.With != nil { + for _, cte := range node.With.CTEs { + cteName := cte.Name.GetRawVal() + v.cteRegistry[cteName] = cte + // Process the CTE's select statement + if cte.Select != nil { + err := cte.Select.Accept(v) + if err != nil { + return err + } + } + } + } + if node.Comments != nil { node.Comments.Accept(v) //nolint:errcheck // future proof } @@ -822,6 +839,13 @@ func (v *indirectExpandAstVisitor) Visit(node sqlparser.SQLNode) error { if node.IsEmpty() { return nil } + // Check if this is a CTE reference + cteName := node.Name.GetRawVal() + if _, isCTE := v.cteRegistry[cteName]; isCTE { + // This is a CTE reference - no further processing needed + // The CTE's select statement has already been processed + return nil + } containsBackendMaterial := v.handlerCtx.GetDBMSInternalRouter().ExprIsRoutable(node) if containsBackendMaterial { v.containsNativeBackendMaterial = true diff --git a/internal/stackql/astanalysis/earlyanalysis/cte_test.go b/internal/stackql/astanalysis/earlyanalysis/cte_test.go new file mode 100644 index 00000000..17577ffa --- /dev/null +++ b/internal/stackql/astanalysis/earlyanalysis/cte_test.go @@ -0,0 +1,228 @@ +package earlyanalysis_test + +import ( + "testing" + + "github.com/stackql/stackql-parser/go/vt/sqlparser" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCTEParsing(t *testing.T) { + t.Run("Simple CTE is parsed correctly", func(t *testing.T) { + query := "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel, ok := stmt.(*sqlparser.Select) + require.True(t, ok, "Statement should be a SELECT") + + // Check that With clause exists + require.NotNil(t, sel.With, "WITH clause should exist") + require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE") + + // Check CTE name + cte := sel.With.CTEs[0] + assert.Equal(t, "cte", cte.Name.GetRawVal(), "CTE name should be 'cte'") + + // Check that CTE has a select statement + require.NotNil(t, cte.Select, "CTE should have a select statement") + }) + + t.Run("Multiple CTEs are parsed correctly", func(t *testing.T) { + query := "WITH a AS (SELECT 1 as x), b AS (SELECT 2 as y) SELECT * FROM a, b" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel, ok := stmt.(*sqlparser.Select) + require.True(t, ok, "Statement should be a SELECT") + + // Check that With clause exists + require.NotNil(t, sel.With, "WITH clause should exist") + require.Len(t, sel.With.CTEs, 2, "Should have 2 CTEs") + + // Check CTE names + assert.Equal(t, "a", sel.With.CTEs[0].Name.GetRawVal(), "First CTE name should be 'a'") + assert.Equal(t, "b", sel.With.CTEs[1].Name.GetRawVal(), "Second CTE name should be 'b'") + }) + + t.Run("Recursive CTE is parsed correctly", func(t *testing.T) { + query := "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 10) SELECT * FROM cte" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel, ok := stmt.(*sqlparser.Select) + require.True(t, ok, "Statement should be a SELECT") + + // Check that With clause exists with Recursive flag + require.NotNil(t, sel.With, "WITH clause should exist") + assert.True(t, sel.With.Recursive, "WITH clause should be RECURSIVE") + require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE") + + // Check CTE name + assert.Equal(t, "cte", sel.With.CTEs[0].Name.GetRawVal(), "CTE name should be 'cte'") + }) + + t.Run("CTE with column aliases", func(t *testing.T) { + query := "WITH cte(col1, col2) AS (SELECT id, name FROM users) SELECT * FROM cte" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel, ok := stmt.(*sqlparser.Select) + require.True(t, ok, "Statement should be a SELECT") + + require.NotNil(t, sel.With, "WITH clause should exist") + require.Len(t, sel.With.CTEs, 1, "Should have 1 CTE") + + cte := sel.With.CTEs[0] + assert.Equal(t, "cte", cte.Name.GetRawVal(), "CTE name should be 'cte'") + + // Check column aliases if present + require.Len(t, cte.Columns, 2, "CTE should have 2 column aliases") + assert.Equal(t, "col1", cte.Columns[0].GetRawVal(), "First column alias should be 'col1'") + assert.Equal(t, "col2", cte.Columns[1].GetRawVal(), "Second column alias should be 'col2'") + }) + + t.Run("Nested CTEs - CTE referencing another CTE", func(t *testing.T) { + query := "WITH a AS (SELECT 1 as x), b AS (SELECT x * 2 as y FROM a) SELECT * FROM b" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel, ok := stmt.(*sqlparser.Select) + require.True(t, ok, "Statement should be a SELECT") + + require.NotNil(t, sel.With, "WITH clause should exist") + require.Len(t, sel.With.CTEs, 2, "Should have 2 CTEs") + }) +} + +func TestCTERegistry(t *testing.T) { + t.Run("CTE registry stores CTEs correctly", func(t *testing.T) { + registry := make(map[string]*sqlparser.CommonTableExpr) + + query := "WITH cte1 AS (SELECT 1), cte2 AS (SELECT 2) SELECT * FROM cte1, cte2" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.NotNil(t, sel.With) + + // Simulate what the visitor does - register CTEs + for _, cte := range sel.With.CTEs { + cteName := cte.Name.GetRawVal() + registry[cteName] = cte + } + + // Verify registry contents + assert.Len(t, registry, 2, "Registry should have 2 CTEs") + assert.Contains(t, registry, "cte1", "Registry should contain 'cte1'") + assert.Contains(t, registry, "cte2", "Registry should contain 'cte2'") + }) + + t.Run("CTE lookup works correctly", func(t *testing.T) { + registry := make(map[string]*sqlparser.CommonTableExpr) + + query := "WITH my_cte AS (SELECT id, name FROM users) SELECT * FROM my_cte" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.NotNil(t, sel.With) + + // Register the CTE + for _, cte := range sel.With.CTEs { + cteName := cte.Name.GetRawVal() + registry[cteName] = cte + } + + // Verify we can look up the CTE + _, isCTE := registry["my_cte"] + assert.True(t, isCTE, "'my_cte' should be found in registry") + + // Verify non-CTE names are not found + _, isNotCTE := registry["users"] + assert.False(t, isNotCTE, "'users' should not be found in registry") + }) +} + +func TestWindowFunctionParsing(t *testing.T) { + t.Run("Window function with OVER clause is parsed correctly", func(t *testing.T) { + query := "SELECT ROW_NUMBER() OVER (ORDER BY id) as row_num FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel, ok := stmt.(*sqlparser.Select) + require.True(t, ok, "Statement should be a SELECT") + + require.Len(t, sel.SelectExprs, 1, "Should have 1 select expression") + + aliased, ok := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + require.True(t, ok, "Select expression should be aliased") + + funcExpr, ok := aliased.Expr.(*sqlparser.FuncExpr) + require.True(t, ok, "Expression should be a FuncExpr") + + assert.Equal(t, "row_number", funcExpr.Name.Lowered(), "Function name should be 'row_number'") + assert.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause") + }) + + t.Run("Window function with PARTITION BY is parsed correctly", func(t *testing.T) { + query := "SELECT SUM(amount) OVER (PARTITION BY category ORDER BY date) as running_sum FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + funcExpr := aliased.Expr.(*sqlparser.FuncExpr) + + assert.Equal(t, "sum", funcExpr.Name.Lowered()) + require.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause") + + // Check partition by exists + require.NotNil(t, funcExpr.Over.PartitionBy, "OVER clause should have PARTITION BY") + }) + + t.Run("Window function with frame specification", func(t *testing.T) { + query := "SELECT SUM(value) OVER (ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumsum FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + funcExpr := aliased.Expr.(*sqlparser.FuncExpr) + + assert.NotNil(t, funcExpr.Over, "FuncExpr should have OVER clause") + }) + + t.Run("Multiple window functions in query", func(t *testing.T) { + query := "SELECT ROW_NUMBER() OVER (ORDER BY id) as rn, RANK() OVER (ORDER BY score DESC) as rank FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 2, "Should have 2 select expressions") + + // Check first window function + aliased1 := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + funcExpr1 := aliased1.Expr.(*sqlparser.FuncExpr) + assert.NotNil(t, funcExpr1.Over, "First FuncExpr should have OVER clause") + + // Check second window function + aliased2 := sel.SelectExprs[1].(*sqlparser.AliasedExpr) + funcExpr2 := aliased2.Expr.(*sqlparser.FuncExpr) + assert.NotNil(t, funcExpr2.Over, "Second FuncExpr should have OVER clause") + }) + + t.Run("Regular function without OVER clause", func(t *testing.T) { + query := "SELECT UPPER(name) as upper_name FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + aliased := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + funcExpr := aliased.Expr.(*sqlparser.FuncExpr) + + assert.Equal(t, "upper", funcExpr.Name.Lowered()) + assert.Nil(t, funcExpr.Over, "UPPER() should not have OVER clause") + }) +} diff --git a/internal/stackql/driver/window_cte_integration_test.go b/internal/stackql/driver/window_cte_integration_test.go new file mode 100644 index 00000000..74cd05dc --- /dev/null +++ b/internal/stackql/driver/window_cte_integration_test.go @@ -0,0 +1,221 @@ +package driver_test + +import ( + "bufio" + "strings" + "testing" + + . "github.com/stackql/stackql/internal/stackql/driver" + "github.com/stackql/stackql/internal/stackql/entryutil" + "github.com/stackql/stackql/internal/stackql/querysubmit" + "github.com/stackql/stackql/internal/stackql/responsehandler" + "github.com/stackql/stackql/internal/test/stackqltestutil" + "github.com/stackql/stackql/internal/test/testobjects" + + lrucache "github.com/stackql/stackql-parser/go/cache" +) + +//nolint:govet,lll // test file +func TestSelectComputeDisksWindowRowNumber(t *testing.T) { + runtimeCtx, err := stackqltestutil.GetRuntimeCtx(testobjects.GetGoogleProviderString(), "text", "TestSelectComputeDisksWindowRowNumber") + if err != nil { + t.Fatalf("Test failed: %v", err) + } + inputBundle, err := stackqltestutil.BuildInputBundle(*runtimeCtx) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + testSubject := func(t *testing.T, outFile *bufio.Writer) { + handlerCtx, err := entryutil.BuildHandlerContext(*runtimeCtx, strings.NewReader(""), lrucache.NewLRUCache(int64(runtimeCtx.QueryCacheSize)), inputBundle.WithStdOut(outFile), true) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + handlerCtx.SetQuery(testobjects.SelectGoogleComputeDisksWindowRowNumber) + dr, _ := NewStackQLDriver(handlerCtx) + querySubmitter := querysubmit.NewQuerySubmitter() + prepareErr := querySubmitter.PrepareQuery(handlerCtx) + if prepareErr != nil { + t.Fatalf("Test failed: %v", prepareErr) + } + response := querySubmitter.SubmitQuery() + responsehandler.HandleResponse(handlerCtx, response) + + dr.ProcessQuery(handlerCtx.GetRawQuery()) + } + + stackqltestutil.SetupSimpleSelectGoogleComputeDisks(t, 1) + stackqltestutil.RunCaptureTestAgainstFiles(t, testSubject, []string{testobjects.ExpectedSelectComputeDisksWindowRowNumber}) +} + +//nolint:govet,lll // test file +func TestSelectComputeDisksWindowRank(t *testing.T) { + runtimeCtx, err := stackqltestutil.GetRuntimeCtx(testobjects.GetGoogleProviderString(), "text", "TestSelectComputeDisksWindowRank") + if err != nil { + t.Fatalf("Test failed: %v", err) + } + inputBundle, err := stackqltestutil.BuildInputBundle(*runtimeCtx) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + testSubject := func(t *testing.T, outFile *bufio.Writer) { + handlerCtx, err := entryutil.BuildHandlerContext(*runtimeCtx, strings.NewReader(""), lrucache.NewLRUCache(int64(runtimeCtx.QueryCacheSize)), inputBundle.WithStdOut(outFile), true) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + handlerCtx.SetQuery(testobjects.SelectGoogleComputeDisksWindowRank) + dr, _ := NewStackQLDriver(handlerCtx) + querySubmitter := querysubmit.NewQuerySubmitter() + prepareErr := querySubmitter.PrepareQuery(handlerCtx) + if prepareErr != nil { + t.Fatalf("Test failed: %v", prepareErr) + } + response := querySubmitter.SubmitQuery() + responsehandler.HandleResponse(handlerCtx, response) + + dr.ProcessQuery(handlerCtx.GetRawQuery()) + } + + stackqltestutil.SetupSimpleSelectGoogleComputeDisks(t, 1) + stackqltestutil.RunCaptureTestAgainstFiles(t, testSubject, []string{testobjects.ExpectedSelectComputeDisksWindowRank}) +} + +//nolint:govet,lll // test file +func TestSelectComputeDisksWindowSum(t *testing.T) { + runtimeCtx, err := stackqltestutil.GetRuntimeCtx(testobjects.GetGoogleProviderString(), "text", "TestSelectComputeDisksWindowSum") + if err != nil { + t.Fatalf("Test failed: %v", err) + } + inputBundle, err := stackqltestutil.BuildInputBundle(*runtimeCtx) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + testSubject := func(t *testing.T, outFile *bufio.Writer) { + handlerCtx, err := entryutil.BuildHandlerContext(*runtimeCtx, strings.NewReader(""), lrucache.NewLRUCache(int64(runtimeCtx.QueryCacheSize)), inputBundle.WithStdOut(outFile), true) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + handlerCtx.SetQuery(testobjects.SelectGoogleComputeDisksWindowSum) + dr, _ := NewStackQLDriver(handlerCtx) + querySubmitter := querysubmit.NewQuerySubmitter() + prepareErr := querySubmitter.PrepareQuery(handlerCtx) + if prepareErr != nil { + t.Fatalf("Test failed: %v", prepareErr) + } + response := querySubmitter.SubmitQuery() + responsehandler.HandleResponse(handlerCtx, response) + + dr.ProcessQuery(handlerCtx.GetRawQuery()) + } + + stackqltestutil.SetupSimpleSelectGoogleComputeDisks(t, 1) + stackqltestutil.RunCaptureTestAgainstFiles(t, testSubject, []string{testobjects.ExpectedSelectComputeDisksWindowSum}) +} + +//nolint:govet,lll // test file +func TestSelectComputeDisksCTESimple(t *testing.T) { + runtimeCtx, err := stackqltestutil.GetRuntimeCtx(testobjects.GetGoogleProviderString(), "text", "TestSelectComputeDisksCTESimple") + if err != nil { + t.Fatalf("Test failed: %v", err) + } + inputBundle, err := stackqltestutil.BuildInputBundle(*runtimeCtx) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + testSubject := func(t *testing.T, outFile *bufio.Writer) { + handlerCtx, err := entryutil.BuildHandlerContext(*runtimeCtx, strings.NewReader(""), lrucache.NewLRUCache(int64(runtimeCtx.QueryCacheSize)), inputBundle.WithStdOut(outFile), true) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + handlerCtx.SetQuery(testobjects.SelectGoogleComputeDisksCTESimple) + dr, _ := NewStackQLDriver(handlerCtx) + querySubmitter := querysubmit.NewQuerySubmitter() + prepareErr := querySubmitter.PrepareQuery(handlerCtx) + if prepareErr != nil { + t.Fatalf("Test failed: %v", prepareErr) + } + response := querySubmitter.SubmitQuery() + responsehandler.HandleResponse(handlerCtx, response) + + dr.ProcessQuery(handlerCtx.GetRawQuery()) + } + + stackqltestutil.SetupSimpleSelectGoogleComputeDisks(t, 1) + stackqltestutil.RunCaptureTestAgainstFiles(t, testSubject, []string{testobjects.ExpectedSelectComputeDisksCTESimple}) +} + +//nolint:govet,lll // test file +func TestSelectComputeDisksCTEWithAgg(t *testing.T) { + runtimeCtx, err := stackqltestutil.GetRuntimeCtx(testobjects.GetGoogleProviderString(), "text", "TestSelectComputeDisksCTEWithAgg") + if err != nil { + t.Fatalf("Test failed: %v", err) + } + inputBundle, err := stackqltestutil.BuildInputBundle(*runtimeCtx) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + testSubject := func(t *testing.T, outFile *bufio.Writer) { + handlerCtx, err := entryutil.BuildHandlerContext(*runtimeCtx, strings.NewReader(""), lrucache.NewLRUCache(int64(runtimeCtx.QueryCacheSize)), inputBundle.WithStdOut(outFile), true) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + handlerCtx.SetQuery(testobjects.SelectGoogleComputeDisksCTEWithAgg) + dr, _ := NewStackQLDriver(handlerCtx) + querySubmitter := querysubmit.NewQuerySubmitter() + prepareErr := querySubmitter.PrepareQuery(handlerCtx) + if prepareErr != nil { + t.Fatalf("Test failed: %v", prepareErr) + } + response := querySubmitter.SubmitQuery() + responsehandler.HandleResponse(handlerCtx, response) + + dr.ProcessQuery(handlerCtx.GetRawQuery()) + } + + stackqltestutil.SetupSimpleSelectGoogleComputeDisks(t, 1) + stackqltestutil.RunCaptureTestAgainstFiles(t, testSubject, []string{testobjects.ExpectedSelectComputeDisksCTEWithAgg}) +} + +//nolint:govet,lll // test file +func TestSelectComputeDisksCTEMultiple(t *testing.T) { + runtimeCtx, err := stackqltestutil.GetRuntimeCtx(testobjects.GetGoogleProviderString(), "text", "TestSelectComputeDisksCTEMultiple") + if err != nil { + t.Fatalf("Test failed: %v", err) + } + inputBundle, err := stackqltestutil.BuildInputBundle(*runtimeCtx) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + testSubject := func(t *testing.T, outFile *bufio.Writer) { + handlerCtx, err := entryutil.BuildHandlerContext(*runtimeCtx, strings.NewReader(""), lrucache.NewLRUCache(int64(runtimeCtx.QueryCacheSize)), inputBundle.WithStdOut(outFile), true) + if err != nil { + t.Fatalf("Test failed: %v", err) + } + + handlerCtx.SetQuery(testobjects.SelectGoogleComputeDisksCTEMultiple) + dr, _ := NewStackQLDriver(handlerCtx) + querySubmitter := querysubmit.NewQuerySubmitter() + prepareErr := querySubmitter.PrepareQuery(handlerCtx) + if prepareErr != nil { + t.Fatalf("Test failed: %v", prepareErr) + } + response := querySubmitter.SubmitQuery() + responsehandler.HandleResponse(handlerCtx, response) + + dr.ProcessQuery(handlerCtx.GetRawQuery()) + } + + // Multiple CTEs need two API calls (one for each CTE that queries the provider) + stackqltestutil.SetupSimpleSelectGoogleComputeDisks(t, 2) + stackqltestutil.RunCaptureTestAgainstFiles(t, testSubject, []string{testobjects.ExpectedSelectComputeDisksCTEMultiple}) +} diff --git a/internal/stackql/parserutil/parser_util.go b/internal/stackql/parserutil/parser_util.go index f94f03d4..49613418 100644 --- a/internal/stackql/parserutil/parser_util.go +++ b/internal/stackql/parserutil/parser_util.go @@ -638,6 +638,10 @@ func inferColNameFromExpr( retVal.IsAggregateExpr = true retVal.Type = aggCol.getReturnType() } + // Window functions (with OVER clause) are also treated as aggregate expressions + if expr.Over != nil { + retVal.IsAggregateExpr = true + } if len(funcNameLowered) >= 4 && funcNameLowered[0:4] == "json" { decoratedColumn := strings.ReplaceAll(retVal.Name, `\"`, `"`) retVal.DecoratedColumn = getDecoratedColRendition(decoratedColumn, alias) diff --git a/internal/stackql/parserutil/parser_util_test.go b/internal/stackql/parserutil/parser_util_test.go new file mode 100644 index 00000000..acab1879 --- /dev/null +++ b/internal/stackql/parserutil/parser_util_test.go @@ -0,0 +1,184 @@ +package parserutil_test + +import ( + "testing" + + "github.com/stackql/stackql-parser/go/vt/sqlparser" + "github.com/stackql/stackql/internal/stackql/parserutil" + "github.com/stackql/stackql/pkg/astformat" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInferColNameFromExpr_WindowFunctions(t *testing.T) { + formatter := astformat.DefaultSelectExprsFormatter + + t.Run("ROW_NUMBER with OVER clause is marked as aggregate", func(t *testing.T) { + query := "SELECT ROW_NUMBER() OVER (ORDER BY id) as row_num FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "ROW_NUMBER() with OVER should be marked as aggregate expression") + assert.Equal(t, "row_num", colHandle.Alias) + }) + + t.Run("SUM with OVER PARTITION BY is marked as aggregate", func(t *testing.T) { + query := "SELECT SUM(amount) OVER (PARTITION BY category) as running_sum FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "SUM() with OVER PARTITION BY should be marked as aggregate expression") + assert.Equal(t, "running_sum", colHandle.Alias) + }) + + t.Run("RANK with OVER ORDER BY is marked as aggregate", func(t *testing.T) { + query := "SELECT RANK() OVER (ORDER BY score DESC) as ranking FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "RANK() with OVER should be marked as aggregate expression") + assert.Equal(t, "ranking", colHandle.Alias) + }) + + t.Run("DENSE_RANK with OVER is marked as aggregate", func(t *testing.T) { + query := "SELECT DENSE_RANK() OVER (PARTITION BY dept ORDER BY salary DESC) as dense_rank FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "DENSE_RANK() with OVER should be marked as aggregate expression") + }) + + t.Run("NTILE with OVER is marked as aggregate", func(t *testing.T) { + query := "SELECT NTILE(4) OVER (ORDER BY id) as quartile FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "NTILE() with OVER should be marked as aggregate expression") + }) + + t.Run("LAG with OVER is marked as aggregate", func(t *testing.T) { + query := "SELECT LAG(value, 1) OVER (ORDER BY date) as prev_value FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "LAG() with OVER should be marked as aggregate expression") + }) + + t.Run("LEAD with OVER is marked as aggregate", func(t *testing.T) { + query := "SELECT LEAD(value, 1) OVER (ORDER BY date) as next_value FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "LEAD() with OVER should be marked as aggregate expression") + }) + + t.Run("FIRST_VALUE with OVER is marked as aggregate", func(t *testing.T) { + query := "SELECT FIRST_VALUE(name) OVER (PARTITION BY category ORDER BY date) as first_name FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "FIRST_VALUE() with OVER should be marked as aggregate expression") + }) + + t.Run("LAST_VALUE with OVER is marked as aggregate", func(t *testing.T) { + query := "SELECT LAST_VALUE(name) OVER (PARTITION BY category ORDER BY date) as last_name FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "LAST_VALUE() with OVER should be marked as aggregate expression") + }) + + t.Run("Regular aggregate function without OVER", func(t *testing.T) { + query := "SELECT COUNT(*) as total FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.True(t, colHandle.IsAggregateExpr, "COUNT() should be marked as aggregate expression") + assert.Equal(t, "total", colHandle.Alias) + }) + + t.Run("Regular function without OVER is not aggregate", func(t *testing.T) { + query := "SELECT UPPER(name) as upper_name FROM t" + stmt, err := sqlparser.Parse(query) + require.NoError(t, err) + + sel := stmt.(*sqlparser.Select) + require.Len(t, sel.SelectExprs, 1) + + aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + colHandle, err := parserutil.InferColNameFromExpr(aliasedExpr, formatter) + require.NoError(t, err) + + assert.False(t, colHandle.IsAggregateExpr, "UPPER() should not be marked as aggregate expression") + }) +} diff --git a/internal/test/testobjects/expected.go b/internal/test/testobjects/expected.go index c547f4a1..d5a83cfb 100644 --- a/internal/test/testobjects/expected.go +++ b/internal/test/testobjects/expected.go @@ -42,4 +42,14 @@ const ( ExpectedSelectComputeDisksAggPaginatedSizeTotal string = "test/assets/expected/aggregated-select/google/disks-paginated/text/disks-sizeGb-total-sum.csv" ExpectedSelectComputeDisksAggPaginatedStringTotal string = "test/assets/expected/aggregated-select/google/disks-paginated/text/disks-total-string-agg.csv" ExpectedSelectExecOrgGetIamPolicyAgg string = "test/assets/expected/aggregated-select/google/cloudresourcemanager/select-exec-getiampolicy-agg.csv" + + // Window function expected outputs. + ExpectedSelectComputeDisksWindowRowNumber string = "test/assets/expected/window-select/google/disks/text/disks-window-row-number.csv" + ExpectedSelectComputeDisksWindowRank string = "test/assets/expected/window-select/google/disks/text/disks-window-rank.csv" + ExpectedSelectComputeDisksWindowSum string = "test/assets/expected/window-select/google/disks/text/disks-window-sum.csv" + + // CTE expected outputs. + ExpectedSelectComputeDisksCTESimple string = "test/assets/expected/cte-select/google/disks/text/disks-cte-simple.csv" + ExpectedSelectComputeDisksCTEWithAgg string = "test/assets/expected/cte-select/google/disks/text/disks-cte-with-agg.csv" + ExpectedSelectComputeDisksCTEMultiple string = "test/assets/expected/cte-select/google/disks/text/disks-cte-multiple.csv" ) diff --git a/internal/test/testobjects/input.go b/internal/test/testobjects/input.go index 1d9e5d1b..cc6e8197 100644 --- a/internal/test/testobjects/input.go +++ b/internal/test/testobjects/input.go @@ -55,6 +55,16 @@ const ( SelectGoogleComputeDisksAggOrderSizeDesc string = `select sizeGb, COUNT(1) as cc from google.compute.disks where zone = 'australia-southeast1-b' AND /* */ project = 'testing-project' GROUP BY sizeGb ORDER BY sizeGb DESC;` SelectGoogleComputeDisksAggSizeTotal string = `select sum(cast(sizeGb as unsigned)) - 10 as cc from google.compute.disks where zone = 'australia-southeast1-b' AND /* */ project = 'testing-project';` SelectGoogleComputeDisksAggStringTotal string = `select group_concat(substr(name, 0, 5)) || ' lalala' as cc from google.compute.disks where zone = 'australia-southeast1-b' AND /* */ project = 'testing-project';` + + // Window function test queries. + SelectGoogleComputeDisksWindowRowNumber string = `select name, sizeGb, ROW_NUMBER() OVER (ORDER BY name) as row_num from google.compute.disks where zone = 'australia-southeast1-b' AND project = 'testing-project' ORDER BY name;` + SelectGoogleComputeDisksWindowRank string = `select name, sizeGb, RANK() OVER (ORDER BY sizeGb) as size_rank from google.compute.disks where zone = 'australia-southeast1-b' AND project = 'testing-project' ORDER BY name;` + SelectGoogleComputeDisksWindowSum string = `select name, sizeGb, SUM(cast(sizeGb as int)) OVER (ORDER BY name) as running_total from google.compute.disks where zone = 'australia-southeast1-b' AND project = 'testing-project' ORDER BY name;` + + // CTE test queries. + SelectGoogleComputeDisksCTESimple string = `WITH disk_cte AS (SELECT name, sizeGb FROM google.compute.disks WHERE zone = 'australia-southeast1-b' AND project = 'testing-project') SELECT name, sizeGb FROM disk_cte ORDER BY name;` + SelectGoogleComputeDisksCTEWithAgg string = `WITH disk_cte AS (SELECT name, sizeGb FROM google.compute.disks WHERE zone = 'australia-southeast1-b' AND project = 'testing-project') SELECT COUNT(*) as disk_count FROM disk_cte;` + SelectGoogleComputeDisksCTEMultiple string = `WITH small_disks AS (SELECT name, sizeGb FROM google.compute.disks WHERE zone = 'australia-southeast1-b' AND project = 'testing-project' AND cast(sizeGb as int) <= 10), large_disks AS (SELECT name, sizeGb FROM google.compute.disks WHERE zone = 'australia-southeast1-b' AND project = 'testing-project' AND cast(sizeGb as int) > 10) SELECT 'small' as category, COUNT(*) as cnt FROM small_disks UNION ALL SELECT 'large' as category, COUNT(*) as cnt FROM large_disks;` ) func GetGoogleProviderString() string { diff --git a/test/assets/expected/cte-select/google/disks/text/disks-cte-multiple.csv b/test/assets/expected/cte-select/google/disks/text/disks-cte-multiple.csv new file mode 100644 index 00000000..9acd3546 --- /dev/null +++ b/test/assets/expected/cte-select/google/disks/text/disks-cte-multiple.csv @@ -0,0 +1,3 @@ +category,cnt +small,3 +large,3 diff --git a/test/assets/expected/cte-select/google/disks/text/disks-cte-simple.csv b/test/assets/expected/cte-select/google/disks/text/disks-cte-simple.csv new file mode 100644 index 00000000..16305d30 --- /dev/null +++ b/test/assets/expected/cte-select/google/disks/text/disks-cte-simple.csv @@ -0,0 +1,7 @@ +name,sizeGb +demo-disk-qq1,10 +demo-disk-qq2,10 +demo-disk-xx2,10 +demo-disk-xx3,20 +demo-disk-xx4,30 +demo-disk-xx5,40 diff --git a/test/assets/expected/cte-select/google/disks/text/disks-cte-with-agg.csv b/test/assets/expected/cte-select/google/disks/text/disks-cte-with-agg.csv new file mode 100644 index 00000000..52e37c19 --- /dev/null +++ b/test/assets/expected/cte-select/google/disks/text/disks-cte-with-agg.csv @@ -0,0 +1,2 @@ +disk_count +6 diff --git a/test/assets/expected/window-select/google/disks/text/disks-window-rank.csv b/test/assets/expected/window-select/google/disks/text/disks-window-rank.csv new file mode 100644 index 00000000..7e938b6a --- /dev/null +++ b/test/assets/expected/window-select/google/disks/text/disks-window-rank.csv @@ -0,0 +1,7 @@ +name,sizeGb,size_rank +demo-disk-qq1,10,1 +demo-disk-qq2,10,1 +demo-disk-xx2,10,1 +demo-disk-xx3,20,4 +demo-disk-xx4,30,5 +demo-disk-xx5,40,6 diff --git a/test/assets/expected/window-select/google/disks/text/disks-window-row-number.csv b/test/assets/expected/window-select/google/disks/text/disks-window-row-number.csv new file mode 100644 index 00000000..b7320f0f --- /dev/null +++ b/test/assets/expected/window-select/google/disks/text/disks-window-row-number.csv @@ -0,0 +1,7 @@ +name,sizeGb,row_num +demo-disk-qq1,10,1 +demo-disk-qq2,10,2 +demo-disk-xx2,10,3 +demo-disk-xx3,20,4 +demo-disk-xx4,30,5 +demo-disk-xx5,40,6 diff --git a/test/assets/expected/window-select/google/disks/text/disks-window-sum.csv b/test/assets/expected/window-select/google/disks/text/disks-window-sum.csv new file mode 100644 index 00000000..1d6b4637 --- /dev/null +++ b/test/assets/expected/window-select/google/disks/text/disks-window-sum.csv @@ -0,0 +1,7 @@ +name,sizeGb,running_total +demo-disk-qq1,10,10 +demo-disk-qq2,10,20 +demo-disk-xx2,10,30 +demo-disk-xx3,20,50 +demo-disk-xx4,30,80 +demo-disk-xx5,40,120