diff --git a/spanner/doc.go b/spanner/doc.go index 8c30fdcc2a47..db4eb09c4fbd 100644 --- a/spanner/doc.go +++ b/spanner/doc.go @@ -338,10 +338,14 @@ NULL-able STRUCT values. # DML and Partitioned DML Spanner supports DML statements like INSERT, UPDATE and DELETE. Use -ReadWriteTransaction.Update to run DML statements. It returns the number of rows -affected. (You can call use ReadWriteTransaction.Query with a DML statement. The -first call to Next on the resulting RowIterator will return iterator.Done, and -the RowCount field of the iterator will hold the number of affected rows.) +ReadWriteTransaction.Update to run simple DML statements and +ReadWriteTransaction.Query to run DML statements containing "THEN RETURN" +clauses. ReadWriteTransaction.Update returns the number of rows affected. +When using ReadWriteTransaction.Update the number of rows affected is +available via RowIterator.RowCount after RowIterator.Next returns iterator.Done. +(You can call use ReadWriteTransaction.Query with a simple DML statement as +well. In this case the first call to Next on the resulting RowIterator will +return iterator.Done.) For large databases, it may be more efficient to partition the DML statement. Use client.PartitionedUpdate to run a DML statement in this way. Not all DML diff --git a/spanner/integration_test.go b/spanner/integration_test.go index dfa7a94c321b..3afb40c2aaf5 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -2943,6 +2943,130 @@ func TestIntegration_DML(t *testing.T) { } } +func TestIntegration_DMLReturning(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) + defer cleanup() + + query := func(ctx context.Context, tx *ReadWriteTransaction, query string, expectedRows int) ([]*Row, int64) { + iter := tx.Query(ctx, Statement{ + SQL: query, + }) + defer iter.Stop() + var rows []*Row + for len(rows) < expectedRows { + row, err := iter.Next() + if err != nil { + t.Fatalf("wanted row, got error %v", err) + } + rows = append(rows, row) + } + if _, err := iter.Next(); err != iterator.Done { + if err == nil { + t.Fatalf("got more rows than expected, wanted %d rows", expectedRows) + } else { + t.Fatalf("wanted iterator.Done, got error %v", err) + } + } + return rows, iter.RowCount + } + + queries := []string{ + `INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) THEN RETURN Nickname, Balance`, + `INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) THEN RETURN Balance + 1 AS BalancePlusOne`, + `UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE THEN RETURN AccountId, Balance`, + `DELETE FROM Accounts WHERE AccountId = 1 THEN RETURN *`, + } + + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + queries = []string{ + `INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) RETURNING Nickname, Balance`, + `INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) RETURNING Balance + 1 AS BalancePlusOne`, + `UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE RETURNING AccountId, Balance`, + `DELETE FROM Accounts WHERE AccountId = 1 RETURNING *`, + } + } + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + // Insert first account, returning Nickname and Balance. + rows, _ := query(ctx, tx, queries[0], 1) + var nick *string + var bal int64 + if err := rows[0].Columns(&nick, &bal); err != nil { + return err + } + if want := (*string)(nil); want != nick { + t.Errorf("got %v, want %v", nick, want) + } + if want := int64(10); want != bal { + t.Errorf("got %d, want %d", bal, want) + } + + // Insert second account, returning Balance+1 + rows, _ = query(ctx, tx, queries[1], 1) + var balancePlusOne int64 + if err := rows[0].Columns(&balancePlusOne); err != nil { + return err + } + if want := int64(21); want != balancePlusOne { + t.Errorf("got %q, want %q", balancePlusOne, want) + } + + // Update both accounts, returning AccountId and Balance for each. + rows, _ = query(ctx, tx, queries[2], 2) + balances := make(map[int64]int64) + for _, row := range rows { + var acctID int64 + var bal int64 + if err := row.Columns(&acctID, &bal); err != nil { + return err + } + balances[acctID] = bal + } + if want := map[int64]int64{1: 110, 2: 120}; !testEqual(want, balances) { + t.Errorf("got %#v, want %#v", balances, want) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + + // Delete an account, returning all columns + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + rows, modifiedRowCount := query(ctx, tx, queries[3], 1) + + if want := int64(1); modifiedRowCount != want { + t.Fatalf("got %v, want %v", modifiedRowCount, want) + } + + var acctID int64 + var nick *string + var bal int64 + if err := rows[0].Columns(&acctID, &nick, &bal); err != nil { + t.Fatal(err) + } + + if want := int64(1); want != acctID { + t.Fatalf("got %v, want %v", acctID, want) + } + if want := (*string)(nil); want != nick { + t.Fatalf("got %v, want %v", nick, want) + } + if want := int64(110); want != bal { + t.Fatalf("got %v, want %v", bal, want) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + func TestIntegration_StructParametersBind(t *testing.T) { t.Parallel() skipUnsupportedPGTest(t) diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md index e737bd6810f2..f41ba59fa6f5 100644 --- a/spanner/spannertest/README.md +++ b/spanner/spannertest/README.md @@ -24,6 +24,7 @@ by ascending esotericism: - SELECT HAVING - more literal types - DEFAULT +- THEN RETURN - expressions that return null for generated columns - generated columns referencing other generated columns - checking dependencies on a generated column before deleting a column diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index 0f6f571c4bde..7d581f50a8dd 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -1466,16 +1466,19 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { /* DELETE [FROM] target_name [[AS] alias] WHERE condition + THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...] UPDATE target_name [[AS] alias] SET update_item [, ...] WHERE condition + THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...] update_item: path_expression = expression | path_expression = DEFAULT INSERT [INTO] target_name (column_name_1 [, ..., column_name_n] ) input + THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...] input: VALUES (row_1_column_1_expr [, ..., row_1_column_n_expr ] ) @@ -1499,9 +1502,15 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { if err != nil { return nil, err } + list, aliases, err := p.parseThenReturn() + if err != nil { + return nil, err + } return &Delete{ - Table: tname, - Where: where, + Table: tname, + Where: where, + Return: list, + ReturnAliases: aliases, }, nil } @@ -1536,6 +1545,12 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { return nil, err } u.Where = where + list, aliases, err := p.parseThenReturn() + if err != nil { + return nil, err + } + u.Return = list + u.ReturnAliases = aliases return u, nil } @@ -1572,16 +1587,30 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) { } } + list, aliases, err := p.parseThenReturn() + if err != nil { + return nil, err + } + return &Insert{ - Table: tname, - Columns: columns, - Input: input, + Table: tname, + Columns: columns, + Input: input, + Return: list, + ReturnAliases: aliases, }, nil } return nil, p.errorf("unknown DML statement") } +func (p *parser) parseThenReturn() ([]Expr, []ID, *parseError) { + if p.eat("THEN", "RETURN") { + return p.parseSelectList() + } + return nil, nil, nil +} + func (p *parser) parseUpdateItem() (UpdateItem, *parseError) { col, err := p.parseTableOrIndexOrColumnName() if err != nil { diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index bcc27eceae42..ccfbabaf9384 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -330,6 +330,15 @@ func TestParseDMLStmt(t *testing.T) { Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, }, }, + { + "INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards') THEN RETURN FirstName", + &Insert{ + Table: "Singers", + Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")}, + Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, + Return: []Expr{ID("FirstName")}, + }, + }, { "INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')", &Insert{ @@ -338,6 +347,28 @@ func TestParseDMLStmt(t *testing.T) { Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, }, }, + { + "INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')" + + " THEN RETURN SingerId, FirstName || ' ' || LastName AS Name", + &Insert{ + Table: "Singers", + Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")}, + Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, + Return: []Expr{ + ID("SingerId"), + ArithOp{ + Op: Concat, + LHS: ArithOp{ + Op: Concat, + LHS: ID("FirstName"), + RHS: StringLiteral(" "), + }, + RHS: ID("LastName"), + }, + }, + ReturnAliases: []ID{ID(""), ID("Name")}, + }, + }, { "INSERT Singers (SingerId, FirstName, LastName) SELECT * FROM UNNEST ([1, 2, 3]) AS data", &Insert{ diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index f9587ae8d55b..d4b11add73d7 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -241,7 +241,9 @@ func (do DatabaseOptions) SQL() string { } func (d *Delete) SQL() string { - return "DELETE FROM " + d.Table.SQL() + " WHERE " + d.Where.SQL() + return "DELETE FROM " + d.Table.SQL() + + " WHERE " + d.Where.SQL() + + thenReturnSQL(d.Return, d.ReturnAliases) } func (u *Update) SQL() string { @@ -258,6 +260,7 @@ func (u *Update) SQL() string { } } str += " WHERE " + u.Where.SQL() + str += thenReturnSQL(u.Return, u.ReturnAliases) return str } @@ -271,6 +274,7 @@ func (i *Insert) SQL() string { } str += ") " str += i.Input.SQL() + str += thenReturnSQL(i.Return, i.ReturnAliases) return str } @@ -293,6 +297,16 @@ func (v Values) SQL() string { return str } +func thenReturnSQL(list []Expr, aliases []ID) string { + if len(list) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString(" THEN RETURN ") + addSelectList(&sb, list, aliases) + return sb.String() +} + func (cd ColumnDef) SQL() string { str := cd.Name.SQL() + " " + cd.Type.SQL() if cd.NotNull { @@ -411,19 +425,7 @@ func (sel Select) addSQL(sb *strings.Builder) { if sel.Distinct { sb.WriteString("DISTINCT ") } - for i, e := range sel.List { - if i > 0 { - sb.WriteString(", ") - } - e.addSQL(sb) - if len(sel.ListAliases) > 0 { - alias := sel.ListAliases[i] - if alias != "" { - sb.WriteString(" AS ") - sb.WriteString(alias.SQL()) - } - } - } + addSelectList(sb, sel.List, sel.ListAliases) if len(sel.From) > 0 { sb.WriteString(" FROM ") for i, f := range sel.From { @@ -443,6 +445,22 @@ func (sel Select) addSQL(sb *strings.Builder) { } } +func addSelectList(sb *strings.Builder, list []Expr, aliases []ID) { + for i, e := range list { + if i > 0 { + sb.WriteString(", ") + } + e.addSQL(sb) + if len(aliases) > 0 { + alias := aliases[i] + if alias != "" { + sb.WriteString(" AS ") + sb.WriteString(alias.SQL()) + } + } + } +} + func (sft SelectFromTable) SQL() string { str := sft.Table.SQL() if len(sft.Hints) > 0 { diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index f68f33754ef5..39239b170747 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -420,6 +420,20 @@ func TestSQL(t *testing.T) { `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Marc", "Richards")`, reparseDML, }, + { + &Insert{ + Table: "Singers", + Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")}, + Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}}, + Return: []Expr{ + ID("SingerId"), + ID("FirstName"), + }, + }, + `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Marc", "Richards")` + + ` THEN RETURN SingerId, FirstName`, + reparseDML, + }, { &Delete{ Table: "Ta", @@ -432,6 +446,19 @@ func TestSQL(t *testing.T) { "DELETE FROM Ta WHERE C > 2", reparseDML, }, + { + &Delete{ + Table: "Ta", + Where: ComparisonOp{ + LHS: ID("C"), + Op: Gt, + RHS: IntegerLiteral(2), + }, + Return: []Expr{Star}, + }, + "DELETE FROM Ta WHERE C > 2 THEN RETURN *", + reparseDML, + }, { &Update{ Table: "Ta", @@ -447,6 +474,29 @@ func TestSQL(t *testing.T) { `UPDATE Ta SET Cb = 4, Ce = "wow", Cf = Cg, Cg = NULL, Ch = DEFAULT WHERE Ca`, reparseDML, }, + { + &Update{ + Table: "Ta", + Items: []UpdateItem{ + {Column: "Cb", Value: IntegerLiteral(4)}, + {Column: "Ce", Value: StringLiteral("wow")}, + {Column: "Cf", Value: ID("Cg")}, + {Column: "Cg", Value: Null}, + {Column: "Ch", Value: nil}, + }, + Where: ID("Ca"), + Return: []Expr{ + ID("Cb"), ID("Ce"), + }, + ReturnAliases: []ID{ + ID("RenamedCb"), + ID(""), + }, + }, + `UPDATE Ta SET Cb = 4, Ce = "wow", Cf = Cg, Cg = NULL, Ch = DEFAULT WHERE Ca` + + ` THEN RETURN Cb AS RenamedCb, Ce`, + reparseDML, + }, { Query{ Select: Select{ diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 87c681c7d259..de7b5a53e08b 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -300,6 +300,12 @@ type DatabaseOptions struct { type Delete struct { Table ID Where BoolExpr + // Expressions in a THEN RETURN clause. + Return []Expr + // If the THEN RETURN list has explicit aliases ("AS alias"), + // ReturnAliases will be populated 1:1 with List; + // aliases that are present will be non-empty. + ReturnAliases []ID // TODO: Alias } @@ -313,6 +319,12 @@ type Insert struct { Table ID Columns []ID Input ValuesOrSelect + // Expressions in a THEN RETURN clause. + Return []Expr + // If the THEN RETURN list has explicit aliases ("AS alias"), + // ReturnAliases will be populated 1:1 with List; + // aliases that are present will be non-empty. + ReturnAliases []ID } // Values represents one or more lists of expressions passed to an `INSERT` statement. @@ -337,6 +349,12 @@ type Update struct { Table ID Items []UpdateItem Where BoolExpr + // Expressions in a THEN RETURN clause. + Return []Expr + // If the THEN RETURN list has explicit aliases ("AS alias"), + // ReturnAliases will be populated 1:1 with List; + // aliases that are present will be non-empty. + ReturnAliases []ID // TODO: Alias }