From c07243ca733ce499bc4359229082abfadec25d79 Mon Sep 17 00:00:00 2001 From: Chris Thunes Date: Thu, 11 Aug 2022 16:56:17 +0000 Subject: [PATCH] feat: Add tests and parser support for THEN RETURN This change adds THEN RETURN support for DML statements. The existing "Query" methods already support executing such statements, so this change is limited to updating documentation, adding integration tests, and updating the spansql package. Support has not been added to the spannertest in-memory implementation of Spanner, but the feature absence has been noted in the corresponding README. --- spanner/doc.go | 12 ++-- spanner/integration_test.go | 124 +++++++++++++++++++++++++++++++++ spanner/spannertest/README.md | 1 + spanner/spansql/parser.go | 39 +++++++++-- spanner/spansql/parser_test.go | 31 +++++++++ spanner/spansql/sql.go | 46 ++++++++---- spanner/spansql/sql_test.go | 50 +++++++++++++ spanner/spansql/types.go | 18 +++++ 8 files changed, 298 insertions(+), 23 deletions(-) diff --git a/spanner/doc.go b/spanner/doc.go index 8c30fdcc2a47..db4eb09c4fbd 100644 --- a/spanner/doc.go +++ b/spanner/doc.go @@ -338,10 +338,14 @@ NULL-able STRUCT values. # DML and Partitioned DML Spanner supports DML statements like INSERT, UPDATE and DELETE. Use -ReadWriteTransaction.Update to run DML statements. It returns the number of rows -affected. (You can call use ReadWriteTransaction.Query with a DML statement. The -first call to Next on the resulting RowIterator will return iterator.Done, and -the RowCount field of the iterator will hold the number of affected rows.) +ReadWriteTransaction.Update to run simple DML statements and +ReadWriteTransaction.Query to run DML statements containing "THEN RETURN" +clauses. ReadWriteTransaction.Update returns the number of rows affected. +When using ReadWriteTransaction.Update the number of rows affected is +available via RowIterator.RowCount after RowIterator.Next returns iterator.Done. +(You can call use ReadWriteTransaction.Query with a simple DML statement as +well. In this case the first call to Next on the resulting RowIterator will +return iterator.Done.) For large databases, it may be more efficient to partition the DML statement. Use client.PartitionedUpdate to run a DML statement in this way. Not all DML diff --git a/spanner/integration_test.go b/spanner/integration_test.go index dfa7a94c321b..3afb40c2aaf5 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -2943,6 +2943,130 @@ func TestIntegration_DML(t *testing.T) { } } +func TestIntegration_DMLReturning(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) + defer cleanup() + + query := func(ctx context.Context, tx *ReadWriteTransaction, query string, expectedRows int) ([]*Row, int64) { + iter := tx.Query(ctx, Statement{ + SQL: query, + }) + defer iter.Stop() + var rows []*Row + for len(rows) < expectedRows { + row, err := iter.Next() + if err != nil { + t.Fatalf("wanted row, got error %v", err) + } + rows = append(rows, row) + } + if _, err := iter.Next(); err != iterator.Done { + if err == nil { + t.Fatalf("got more rows than expected, wanted %d rows", expectedRows) + } else { + t.Fatalf("wanted iterator.Done, got error %v", err) + } + } + return rows, iter.RowCount + } + + queries := []string{ + `INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) THEN RETURN Nickname, Balance`, + `INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) THEN RETURN Balance + 1 AS BalancePlusOne`, + `UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE THEN RETURN AccountId, Balance`, + `DELETE FROM Accounts WHERE AccountId = 1 THEN RETURN *`, + } + + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + queries = []string{ + `INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) RETURNING Nickname, Balance`, + `INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) RETURNING Balance + 1 AS BalancePlusOne`, + `UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE RETURNING AccountId, Balance`, + `DELETE FROM Accounts WHERE AccountId = 1 RETURNING *`, + } + } + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + // Insert first account, returning Nickname and Balance. + rows, _ := query(ctx, tx, queries[0], 1) + var nick *string + var bal int64 + if err := rows[0].Columns(&nick, &bal); err != nil { + return err + } + if want := (*string)(nil); want != nick { + t.Errorf("got %v, want %v", nick, want) + } + if want := int64(10); want != bal { + t.Errorf("got %d, want %d", bal, want) + } + + // Insert second account, returning Balance+1 + rows, _ = query(ctx, tx, queries[1], 1) + var balancePlusOne int64 + if err := rows[0].Columns(&balancePlusOne); err != nil { + return err + } + if want := int64(21); want != balancePlusOne { + t.Errorf("got %q, want %q", balancePlusOne, want) + } + + // Update both accounts, returning AccountId and Balance for each. + rows, _ = query(ctx, tx, queries[2], 2) + balances := make(map[int64]int64) + for _, row := range rows { + var acctID int64 + var bal int64 + if err := row.Columns(&acctID, &bal); err != nil { + return err + } + balances[acctID] = bal + } + if want := map[int64]int64{1: 110, 2: 120}; !testEqual(want, balances) { + t.Errorf("got %#v, want %#v", balances, want) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + + // Delete an account, returning all columns + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + rows, modifiedRowCount := query(ctx, tx, queries[3], 1) + + if want := int64(1); modifiedRowCount != want { + t.Fatalf("got %v, want %v", modifiedRowCount, want) + } + + var acctID int64 + var nick *string + var bal int64 + if err := rows[0].Columns(&acctID, &nick, &bal); err != nil { + t.Fatal(err) + } + + if want := int64(1); want != acctID { + t.Fatalf("got %v, want %v", acctID, want) + } + if want := (*string)(nil); want != nick { + t.Fatalf("got %v, want %v", nick, want) + } + if want := int64(110); want != bal { + t.Fatalf("got %v, want %v", bal, want) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + func TestIntegration_StructParametersBind(t *testing.T) { t.Parallel() skipUnsupportedPGTest(t) diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md index e737bd6810f2..f41ba59fa6f5 100644 --- a/spanner/spannertest/README.md +++ b/spanner/spannertest/README.md @@ -24,6 +24,7 @@ by ascending esotericism: - SELECT HAVING - more literal types - DEFAULT +- THEN RETURN - expressions that return null for generated columns - generated columns referencing other generated columns - checking dependencies on a generated column before deleting a column diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index 0f6f571c4bde..7d581f50a8dd 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -1466,16 +1466,19 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { /* DELETE [FROM] target_name [[AS] alias] WHERE condition + THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...] UPDATE target_name [[AS] alias] SET update_item [, ...] WHERE condition + THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...] update_item: path_expression = expression | path_expression = DEFAULT INSERT [INTO] target_name (column_name_1 [, ..., column_name_n] ) input + THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...] input: VALUES (row_1_column_1_expr [, ..., row_1_column_n_expr ] ) @@ -1499,9 +1502,15 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { if err != nil { return nil, err } + list, aliases, err := p.parseThenReturn() + if err != nil { + return nil, err + } return &Delete{ - Table: tname, - Where: where, + Table: tname, + Where: where, + Return: list, + ReturnAliases: aliases, }, nil } @@ -1536,6 +1545,12 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { return nil, err } u.Where = where + list, aliases, err := p.parseThenReturn() + if err != nil { + return nil, err + } + u.Return = list + u.ReturnAliases = aliases return u, nil } @@ -1572,16 +1587,30 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { } } + list, aliases, err := p.parseThenReturn() + if err != nil { + return nil, err + } + return &Insert{ - Table: tname, - Columns: columns, - Input: input, + Table: tname, + Columns: columns, + Input: input, + Return: list, + ReturnAliases: aliases, }, nil } return nil, p.errorf("unknown DML statement") } +func (p *parser) parseThenReturn() ([]Expr, []ID, *parseError) { + if p.eat("THEN", "RETURN") { + return p.parseSelectList() + } + return nil, nil, nil +} + func (p *parser) parseUpdateItem() (UpdateItem, *parseError) { col, err := p.parseTableOrIndexOrColumnName() if err != nil { diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index bcc27eceae42..ccfbabaf9384 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -330,6 +330,15 @@ func TestParseDMLStmt(t *testing.T) { Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, }, }, + { + "INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards') THEN RETURN FirstName", + &Insert{ + Table: "Singers", + Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")}, + Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, + Return: []Expr{ID("FirstName")}, + }, + }, { "INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')", &Insert{ @@ -338,6 +347,28 @@ func TestParseDMLStmt(t *testing.T) { Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, }, }, + { + "INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')" + + " THEN RETURN SingerId, FirstName || ' ' || LastName AS Name", + &Insert{ + Table: "Singers", + Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")}, + Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, + Return: []Expr{ + ID("SingerId"), + ArithOp{ + Op: Concat, + LHS: ArithOp{ + Op: Concat, + LHS: ID("FirstName"), + RHS: StringLiteral(" "), + }, + RHS: ID("LastName"), + }, + }, + ReturnAliases: []ID{ID(""), ID("Name")}, + }, + }, { "INSERT Singers (SingerId, FirstName, LastName) SELECT * FROM UNNEST ([1, 2, 3]) AS data", &Insert{ diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index f9587ae8d55b..d4b11add73d7 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -241,7 +241,9 @@ func (do DatabaseOptions) SQL() string { } func (d *Delete) SQL() string { - return "DELETE FROM " + d.Table.SQL() + " WHERE " + d.Where.SQL() + return "DELETE FROM " + d.Table.SQL() + + " WHERE " + d.Where.SQL() + + thenReturnSQL(d.Return, d.ReturnAliases) } func (u *Update) SQL() string { @@ -258,6 +260,7 @@ func (u *Update) SQL() string { } } str += " WHERE " + u.Where.SQL() + str += thenReturnSQL(u.Return, u.ReturnAliases) return str } @@ -271,6 +274,7 @@ func (i *Insert) SQL() string { } str += ") " str += i.Input.SQL() + str += thenReturnSQL(i.Return, i.ReturnAliases) return str } @@ -293,6 +297,16 @@ func (v Values) SQL() string { return str } +func thenReturnSQL(list []Expr, aliases []ID) string { + if len(list) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString(" THEN RETURN ") + addSelectList(&sb, list, aliases) + return sb.String() +} + func (cd ColumnDef) SQL() string { str := cd.Name.SQL() + " " + cd.Type.SQL() if cd.NotNull { @@ -411,19 +425,7 @@ func (sel Select) addSQL(sb *strings.Builder) { if sel.Distinct { sb.WriteString("DISTINCT ") } - for i, e := range sel.List { - if i > 0 { - sb.WriteString(", ") - } - e.addSQL(sb) - if len(sel.ListAliases) > 0 { - alias := sel.ListAliases[i] - if alias != "" { - sb.WriteString(" AS ") - sb.WriteString(alias.SQL()) - } - } - } + addSelectList(sb, sel.List, sel.ListAliases) if len(sel.From) > 0 { sb.WriteString(" FROM ") for i, f := range sel.From { @@ -443,6 +445,22 @@ func (sel Select) addSQL(sb *strings.Builder) { } } +func addSelectList(sb *strings.Builder, list []Expr, aliases []ID) { + for i, e := range list { + if i > 0 { + sb.WriteString(", ") + } + e.addSQL(sb) + if len(aliases) > 0 { + alias := aliases[i] + if alias != "" { + sb.WriteString(" AS ") + sb.WriteString(alias.SQL()) + } + } + } +} + func (sft SelectFromTable) SQL() string { str := sft.Table.SQL() if len(sft.Hints) > 0 { diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index f68f33754ef5..39239b170747 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -420,6 +420,20 @@ func TestSQL(t *testing.T) { `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Marc", "Richards")`, reparseDML, }, + { + &Insert{ + Table: "Singers", + Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")}, + Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, + Return: []Expr{ + ID("SingerId"), + ID("FirstName"), + }, + }, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Marc", "Richards")` + + ` THEN RETURN SingerId, FirstName`, + reparseDML, + }, { &Delete{ Table: "Ta", @@ -432,6 +446,19 @@ func TestSQL(t *testing.T) { "DELETE FROM Ta WHERE C > 2", reparseDML, }, + { + &Delete{ + Table: "Ta", + Where: ComparisonOp{ + LHS: ID("C"), + Op: Gt, + RHS: IntegerLiteral(2), + }, + Return: []Expr{Star}, + }, + "DELETE FROM Ta WHERE C > 2 THEN RETURN *", + reparseDML, + }, { &Update{ Table: "Ta", @@ -447,6 +474,29 @@ func TestSQL(t *testing.T) { `UPDATE Ta SET Cb = 4, Ce = "wow", Cf = Cg, Cg = NULL, Ch = DEFAULT WHERE Ca`, reparseDML, }, + { + &Update{ + Table: "Ta", + Items: []UpdateItem{ + {Column: "Cb", Value: IntegerLiteral(4)}, + {Column: "Ce", Value: StringLiteral("wow")}, + {Column: "Cf", Value: ID("Cg")}, + {Column: "Cg", Value: Null}, + {Column: "Ch", Value: nil}, + }, + Where: ID("Ca"), + Return: []Expr{ + ID("Cb"), ID("Ce"), + }, + ReturnAliases: []ID{ + ID("RenamedCb"), + ID(""), + }, + }, + `UPDATE Ta SET Cb = 4, Ce = "wow", Cf = Cg, Cg = NULL, Ch = DEFAULT WHERE Ca` + + ` THEN RETURN Cb AS RenamedCb, Ce`, + reparseDML, + }, { Query{ Select: Select{ diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 87c681c7d259..de7b5a53e08b 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -300,6 +300,12 @@ type DatabaseOptions struct { type Delete struct { Table ID Where BoolExpr + // Expressions in a THEN RETURN clause. + Return []Expr + // If the THEN RETURN list has explicit aliases ("AS alias"), + // ReturnAliases will be populated 1:1 with List; + // aliases that are present will be non-empty. + ReturnAliases []ID // TODO: Alias } @@ -313,6 +319,12 @@ type Insert struct { Table ID Columns []ID Input ValuesOrSelect + // Expressions in a THEN RETURN clause. + Return []Expr + // If the THEN RETURN list has explicit aliases ("AS alias"), + // ReturnAliases will be populated 1:1 with List; + // aliases that are present will be non-empty. + ReturnAliases []ID } // Values represents one or more lists of expressions passed to an `INSERT` statement. @@ -337,6 +349,12 @@ type Update struct { Table ID Items []UpdateItem Where BoolExpr + // Expressions in a THEN RETURN clause. + Return []Expr + // If the THEN RETURN list has explicit aliases ("AS alias"), + // ReturnAliases will be populated 1:1 with List; + // aliases that are present will be non-empty. + ReturnAliases []ID // TODO: Alias }