Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spanner/spansql): Add tests and parser support for THEN RETURN #6515

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions spanner/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,14 @@ NULL-able STRUCT values.
# DML and Partitioned DML

Spanner supports DML statements like INSERT, UPDATE and DELETE. Use
ReadWriteTransaction.Update to run DML statements. It returns the number of rows
affected. (You can call use ReadWriteTransaction.Query with a DML statement. The
first call to Next on the resulting RowIterator will return iterator.Done, and
the RowCount field of the iterator will hold the number of affected rows.)
ReadWriteTransaction.Update to run simple DML statements and
ReadWriteTransaction.Query to run DML statements containing "THEN RETURN"
rajatbhatta marked this conversation as resolved.
Show resolved Hide resolved
(GoogleSQL dialect) or "RETURNING" (PostgreSQL dialect) clauses.
ReadWriteTransaction.Update returns the number of rows affected. When using
ReadWriteTransaction.Query the number of rows affected is available via
RowIterator.RowCount after RowIterator.Next returns iterator.Done. (You can call
ReadWriteTransaction.Query with a simple DML statement as well. In this case
the first call to Next on the resulting RowIterator will return iterator.Done.)

For large databases, it may be more efficient to partition the DML statement.
Use client.PartitionedUpdate to run a DML statement in this way. Not all DML
Expand Down
208 changes: 208 additions & 0 deletions spanner/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2953,6 +2953,214 @@ func TestIntegration_DML(t *testing.T) {
}
}

func TestIntegration_DMLReturning(t *testing.T) {
c2nes marked this conversation as resolved.
Show resolved Hide resolved
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements])
defer cleanup()

query := func(ctx context.Context, tx *ReadWriteTransaction, query string, expectedRows int) ([]*Row, int64) {
iter := tx.Query(ctx, Statement{
SQL: query,
})
defer iter.Stop()
var rows []*Row
for len(rows) < expectedRows {
row, err := iter.Next()
if err != nil {
t.Fatalf("wanted row, got error %v", err)
}
rows = append(rows, row)
}
if _, err := iter.Next(); err != iterator.Done {
if err == nil {
t.Fatalf("got more rows than expected, wanted %d rows", expectedRows)
} else {
t.Fatalf("wanted iterator.Done, got error %v", err)
}
}
return rows, iter.RowCount
}

queries := []string{
`INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) THEN RETURN Nickname, Balance`,
`INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) THEN RETURN Balance + 1 AS BalancePlusOne`,
`UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE THEN RETURN AccountId, Balance`,
`DELETE FROM Accounts WHERE AccountId = 1 THEN RETURN *`,
}

if testDialect == adminpb.DatabaseDialect_POSTGRESQL {
queries = []string{
`INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) RETURNING Nickname, Balance`,
`INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) RETURNING Balance + 1 AS BalancePlusOne`,
`UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE RETURNING AccountId, Balance`,
`DELETE FROM Accounts WHERE AccountId = 1 RETURNING *`,
}
}
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
// Insert first account, returning Nickname and Balance.
rows, _ := query(ctx, tx, queries[0], 1)
var nick *string
var bal int64
if err := rows[0].Columns(&nick, &bal); err != nil {
return err
}
if want := (*string)(nil); want != nick {
t.Errorf("got %v, want %v", nick, want)
}
if want := int64(10); want != bal {
t.Errorf("got %d, want %d", bal, want)
}

// Insert second account, returning Balance+1
rows, _ = query(ctx, tx, queries[1], 1)
var balancePlusOne int64
if err := rows[0].Columns(&balancePlusOne); err != nil {
return err
}
if want := int64(21); want != balancePlusOne {
t.Errorf("got %q, want %q", balancePlusOne, want)
}

// Update both accounts, returning AccountId and Balance for each.
rows, _ = query(ctx, tx, queries[2], 2)
balances := make(map[int64]int64)
for _, row := range rows {
var acctID int64
var bal int64
if err := row.Columns(&acctID, &bal); err != nil {
return err
}
balances[acctID] = bal
}
if want := map[int64]int64{1: 110, 2: 120}; !testEqual(want, balances) {
t.Errorf("got %#v, want %#v", balances, want)
}

return nil
})
if err != nil {
t.Fatal(err)
}

// Delete an account, returning all columns
_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
rows, modifiedRowCount := query(ctx, tx, queries[3], 1)

if want := int64(1); modifiedRowCount != want {
t.Fatalf("got %v, want %v", modifiedRowCount, want)
}

var acctID int64
var nick *string
var bal int64
if err := rows[0].Columns(&acctID, &nick, &bal); err != nil {
t.Fatal(err)
}

if want := int64(1); want != acctID {
t.Fatalf("got %v, want %v", acctID, want)
}
if want := (*string)(nil); want != nick {
t.Fatalf("got %v, want %v", nick, want)
}
if want := int64(110); want != bal {
t.Fatalf("got %v, want %v", bal, want)
}

return nil
})
if err != nil {
t.Fatal(err)
}
}

func TestIntegration_DMLReturning_ViaUpdate(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements])
defer cleanup()

query := `INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) THEN RETURN Nickname, Balance`
if testDialect == adminpb.DatabaseDialect_POSTGRESQL {
query = `INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) RETURNING Nickname, Balance`
}
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
// DML statements with returning clauses, executed via Update should
// behave as-if the returning clause were not given.
rowCount, err := tx.Update(ctx, NewStatement(query))
if err != nil {
return err
}

if want := int64(1); want != rowCount {
t.Errorf("got %v, want %v", rowCount, want)
}

return nil
})
if err != nil {
t.Fatal(err)
}
}

func TestIntegration_DMLReturning_ViaBatchUpdate(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements])
defer cleanup()

queries := []string{
`INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) THEN RETURN Nickname, Balance`,
`INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) THEN RETURN Balance + 1 AS BalancePlusOne`,
`SELECT Balance FROM Accounts WHERE AccountId = 2`,
}

if testDialect == adminpb.DatabaseDialect_POSTGRESQL {
queries = []string{
`INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) RETURNING Nickname, Balance`,
`INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) RETURNING Balance + 1 AS BalancePlusOne`,
`SELECT Balance FROM Accounts WHERE AccountId = 2`,
}
}

_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
// BatchUpdate should effectively ignore the returning clauses.
rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(queries[0]), NewStatement(queries[1])})
if err != nil {
return err
}

if want := []int64{1, 1}; !testEqual(want, rowCounts) {
t.Errorf("got %v, want %v", rowCounts, want)
}

iter := tx.Query(ctx, NewStatement(queries[2]))
row, err := iter.Next()
if err != nil {
return err
}
var balance int64
if err := row.Columns(&balance); err != nil {
return err
}
if want := int64(20); balance != want {
t.Errorf("got %v, want %v", balance, want)
}

return nil
})
if err != nil {
t.Fatal(err)
}
}

func TestIntegration_StructParametersBind(t *testing.T) {
t.Parallel()
skipUnsupportedPGTest(t)
Expand Down
1 change: 1 addition & 0 deletions spanner/spannertest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ by ascending esotericism:
- SELECT HAVING
- more literal types
- DEFAULT
- THEN RETURN
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are supporting it through this PR, right? I guess we want to say PG.Returning is not supported.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I did not intend to add this support myself. We will be adding support to the standard Spanner Emulator however.

- expressions that return null for generated columns
- generated columns referencing other generated columns
- checking dependencies on a generated column before deleting a column
Expand Down
39 changes: 34 additions & 5 deletions spanner/spansql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1482,16 +1482,19 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
/*
DELETE [FROM] target_name [[AS] alias]
WHERE condition
THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...]

UPDATE target_name [[AS] alias]
SET update_item [, ...]
WHERE condition
THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...]

update_item: path_expression = expression | path_expression = DEFAULT

INSERT [INTO] target_name
(column_name_1 [, ..., column_name_n] )
input
THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...]

input:
VALUES (row_1_column_1_expr [, ..., row_1_column_n_expr ] )
Expand All @@ -1515,9 +1518,15 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
if err != nil {
return nil, err
}
list, aliases, err := p.parseThenReturn()
if err != nil {
return nil, err
}
return &Delete{
Table: tname,
Where: where,
Table: tname,
Where: where,
Return: list,
ReturnAliases: aliases,
}, nil
}

Expand Down Expand Up @@ -1552,6 +1561,12 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
return nil, err
}
u.Where = where
list, aliases, err := p.parseThenReturn()
if err != nil {
return nil, err
}
u.Return = list
u.ReturnAliases = aliases
return u, nil
}

Expand Down Expand Up @@ -1588,16 +1603,30 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
}
}

list, aliases, err := p.parseThenReturn()
if err != nil {
return nil, err
}

return &Insert{
Table: tname,
Columns: columns,
Input: input,
Table: tname,
Columns: columns,
Input: input,
Return: list,
ReturnAliases: aliases,
}, nil
}

return nil, p.errorf("unknown DML statement")
}

func (p *parser) parseThenReturn() ([]Expr, []ID, *parseError) {
if p.eat("THEN", "RETURN") {
return p.parseSelectList()
}
return nil, nil, nil
}

func (p *parser) parseUpdateItem() (UpdateItem, *parseError) {
col, err := p.parseTableOrIndexOrColumnName()
if err != nil {
Expand Down
31 changes: 31 additions & 0 deletions spanner/spansql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ func TestParseDMLStmt(t *testing.T) {
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
},
},
{
"INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards') THEN RETURN FirstName",
&Insert{
Table: "Singers",
Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")},
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
Return: []Expr{ID("FirstName")},
},
},
{
"INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')",
&Insert{
Expand All @@ -338,6 +347,28 @@ func TestParseDMLStmt(t *testing.T) {
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
},
},
{
"INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')" +
" THEN RETURN SingerId, FirstName || ' ' || LastName AS Name",
&Insert{
Table: "Singers",
Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")},
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
Return: []Expr{
ID("SingerId"),
ArithOp{
Op: Concat,
LHS: ArithOp{
Op: Concat,
LHS: ID("FirstName"),
RHS: StringLiteral(" "),
},
RHS: ID("LastName"),
},
},
ReturnAliases: []ID{ID(""), ID("Name")},
},
},
{
"INSERT Singers (SingerId, FirstName, LastName) SELECT * FROM UNNEST ([1, 2, 3]) AS data",
&Insert{
Expand Down
Loading