Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
The diff you're trying to view is too large. We only load the first 3000 changed files.
208 changes: 208 additions & 0 deletions cmd/regenerate-ast/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
package main

import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/sqlc-dev/doubleclick/parser"
)

func main() {
testName := flag.String("test", "", "Single test directory name to process (if empty, process all)")
dryRun := flag.Bool("dry-run", false, "Print what would be done without making changes")
flag.Parse()

testdataDir := "parser/testdata"

if *testName != "" {
// Process single test
if err := processTest(filepath.Join(testdataDir, *testName), *dryRun); err != nil {
fmt.Fprintf(os.Stderr, "Error processing %s: %v\n", *testName, err)
os.Exit(1)
}
return
}

// Process all tests
entries, err := os.ReadDir(testdataDir)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading testdata: %v\n", err)
os.Exit(1)
}

var processed, skipped, errors int
for _, entry := range entries {
if !entry.IsDir() {
continue
}
testDir := filepath.Join(testdataDir, entry.Name())
if err := processTest(testDir, *dryRun); err != nil {
if strings.Contains(err.Error(), "no statements found") {
skipped++
} else {
fmt.Fprintf(os.Stderr, "Error processing %s: %v\n", entry.Name(), err)
errors++
}
} else {
processed++
}
}

fmt.Printf("\nProcessed: %d, Skipped: %d, Errors: %d\n", processed, skipped, errors)
if errors > 0 {
os.Exit(1)
}
}

func processTest(testDir string, dryRun bool) error {
queryPath := filepath.Join(testDir, "query.sql")
queryBytes, err := os.ReadFile(queryPath)
if err != nil {
return fmt.Errorf("reading query.sql: %w", err)
}

statements := splitStatements(string(queryBytes))
if len(statements) == 0 {
return fmt.Errorf("no statements found")
}

testName := filepath.Base(testDir)
goldenDir := filepath.Join(testDir, "golden", "ast")

if dryRun {
fmt.Printf("Would process %s (%d statements) -> %s/\n", testName, len(statements), goldenDir)
for i, stmt := range statements {
fmt.Printf(" [%d] %s -> stmt_%04d.json\n", i+1, truncate(stmt, 60), i+1)
}
return nil
}

// Create golden/ast directory
if err := os.MkdirAll(goldenDir, 0755); err != nil {
return fmt.Errorf("creating golden directory: %w", err)
}

var stmtErrors []string
for i, stmt := range statements {
stmtNum := i + 1

// Parse the statement
stmts, parseErr := parser.Parse(context.Background(), strings.NewReader(stmt))
if len(stmts) == 0 {
stmtErrors = append(stmtErrors, fmt.Sprintf("stmt %d: parse error: %v", stmtNum, parseErr))
continue
}

// Marshal to pretty JSON
jsonBytes, err := json.MarshalIndent(stmts[0], "", " ")
if err != nil {
stmtErrors = append(stmtErrors, fmt.Sprintf("stmt %d: json marshal error: %v", stmtNum, err))
continue
}

// Write to golden file
outputPath := filepath.Join(goldenDir, fmt.Sprintf("stmt_%04d.json", stmtNum))
if err := os.WriteFile(outputPath, append(jsonBytes, '\n'), 0644); err != nil {
return fmt.Errorf("writing %s: %w", outputPath, err)
}
}

// Print summary
if len(stmtErrors) > 0 {
fmt.Printf("%s: %d stmts, %d errors\n", testName, len(statements), len(stmtErrors))
for _, e := range stmtErrors {
fmt.Printf(" %s\n", e)
}
} else {
fmt.Printf("%s: %d stmts OK\n", testName, len(statements))
}

return nil
}

// splitStatements splits SQL content into individual statements.
func splitStatements(content string) []string {
var statements []string
var current strings.Builder

lines := strings.Split(content, "\n")
for _, line := range lines {
trimmed := strings.TrimSpace(line)

// Skip empty lines and full-line comments
if trimmed == "" || strings.HasPrefix(trimmed, "--") {
continue
}

// Remove inline comments
if idx := findCommentStart(trimmed); idx >= 0 {
trimmed = strings.TrimSpace(trimmed[:idx])
if trimmed == "" {
continue
}
}

if current.Len() > 0 {
current.WriteString(" ")
}
current.WriteString(trimmed)

if strings.HasSuffix(trimmed, ";") {
stmt := strings.TrimSpace(current.String())
if stmt != "" && stmt != ";" {
statements = append(statements, stmt)
}
current.Reset()
}
}

if current.Len() > 0 {
stmt := strings.TrimSpace(current.String())
if stmt != "" {
statements = append(statements, stmt)
}
}

return statements
}

func findCommentStart(line string) int {
inString := false
var stringChar byte
for i := 0; i < len(line); i++ {
c := line[i]
if inString {
if c == '\\' && i+1 < len(line) {
i++
continue
}
if c == stringChar {
inString = false
}
} else {
if c == '\'' || c == '"' || c == '`' {
inString = true
stringChar = c
} else if c == '-' && i+1 < len(line) && line[i+1] == '-' {
if i+2 >= len(line) || line[i+2] == ' ' || line[i+2] == '\t' {
return i
}
}
}
}
return -1
}

func truncate(s string, n int) string {
s = strings.ReplaceAll(s, "\n", " ")
s = strings.Join(strings.Fields(s), " ")
if len(s) <= n {
return s
}
return s[:n-3] + "..."
}
22 changes: 22 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
// Use with: go test ./parser -check-explain -v
var checkExplain = flag.Bool("check-explain", false, "Run skipped explain_todo tests to see which ones now pass")

// skipAST disables AST golden file verification.
// Use with: go test ./parser -skip-ast to skip AST verification
var skipAST = flag.Bool("skip-ast", false, "Skip AST golden file verification")

// testMetadata holds optional metadata for a test case
type testMetadata struct {
ExplainTodo map[string]bool `json:"explain_todo,omitempty"` // map of stmtN -> true to skip specific statements
Expand Down Expand Up @@ -332,6 +336,24 @@ func TestParser(t *testing.T) {
}
}

// Check AST golden file unless -skip-ast is set
if !*skipAST {
astGoldenPath := filepath.Join(testDir, "golden", "ast", fmt.Sprintf("stmt_%04d.json", stmtIndex))
if expectedASTBytes, err := os.ReadFile(astGoldenPath); err == nil {
// Marshal actual AST to JSON
actualASTBytes, err := json.MarshalIndent(stmts[0], "", " ")
if err != nil {
t.Errorf("Failed to marshal AST to JSON: %v", err)
} else {
expectedAST := strings.TrimSpace(string(expectedASTBytes))
actualAST := strings.TrimSpace(string(actualASTBytes))
if expectedAST != actualAST {
t.Errorf("AST mismatch for %s\nExpected:\n%s\n\nGot:\n%s", astGoldenPath, expectedAST, actualAST)
}
}
}
}

})
}
})
Expand Down
25 changes: 25 additions & 0 deletions parser/testdata/00002_system_numbers/golden/ast/stmt_0002.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"selects": [
{
"columns": [
{}
],
"from": {
"tables": [
{
"table": {
"table": {
"database": "system",
"table": "numbers"
}
}
}
]
},
"limit": {
"type": "Integer",
"value": 3
}
}
]
}
43 changes: 43 additions & 0 deletions parser/testdata/00002_system_numbers/golden/ast/stmt_0003.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"selects": [
{
"columns": [
{
"parts": [
"sys_num",
"number"
]
}
],
"from": {
"tables": [
{
"table": {
"table": {
"database": "system",
"table": "numbers"
},
"alias": "sys_num"
}
}
]
},
"where": {
"left": {
"parts": [
"number"
]
},
"op": "\u003e",
"right": {
"type": "Integer",
"value": 2
}
},
"limit": {
"type": "Integer",
"value": 2
}
}
]
}
41 changes: 41 additions & 0 deletions parser/testdata/00002_system_numbers/golden/ast/stmt_0004.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"selects": [
{
"columns": [
{
"parts": [
"number"
]
}
],
"from": {
"tables": [
{
"table": {
"table": {
"database": "system",
"table": "numbers"
}
}
}
]
},
"where": {
"left": {
"parts": [
"number"
]
},
"op": "\u003e=",
"right": {
"type": "Integer",
"value": 5
}
},
"limit": {
"type": "Integer",
"value": 2
}
}
]
}
37 changes: 37 additions & 0 deletions parser/testdata/00002_system_numbers/golden/ast/stmt_0005.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"selects": [
{
"columns": [
{}
],
"from": {
"tables": [
{
"table": {
"table": {
"database": "system",
"table": "numbers"
}
}
}
]
},
"where": {
"left": {
"parts": [
"number"
]
},
"op": "==",
"right": {
"type": "Integer",
"value": 7
}
},
"limit": {
"type": "Integer",
"value": 1
}
}
]
}
Loading