Skip to content

Commit

Permalink
feat: Add tests and parser support for THEN RETURN
Browse files Browse the repository at this point in the history
This change adds THEN RETURN support for DML statements. The existing
"Query" methods already support executing such statements, so this
change is limited to updating documentation, adding integration tests,
and updating the spansql package. Support has not been added to the
spannertest in-memory implementation of Spanner, but the feature absence
has been noted in the corresponding README.
  • Loading branch information
c2nes committed Aug 15, 2022
1 parent 370e23e commit c07243c
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 23 deletions.
12 changes: 8 additions & 4 deletions spanner/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,14 @@ NULL-able STRUCT values.
# DML and Partitioned DML
Spanner supports DML statements like INSERT, UPDATE and DELETE. Use
ReadWriteTransaction.Update to run DML statements. It returns the number of rows
affected. (You can call use ReadWriteTransaction.Query with a DML statement. The
first call to Next on the resulting RowIterator will return iterator.Done, and
the RowCount field of the iterator will hold the number of affected rows.)
ReadWriteTransaction.Update to run simple DML statements and
ReadWriteTransaction.Query to run DML statements containing "THEN RETURN"
clauses. ReadWriteTransaction.Update returns the number of rows affected.
When using ReadWriteTransaction.Update the number of rows affected is
available via RowIterator.RowCount after RowIterator.Next returns iterator.Done.
(You can call use ReadWriteTransaction.Query with a simple DML statement as
well. In this case the first call to Next on the resulting RowIterator will
return iterator.Done.)
For large databases, it may be more efficient to partition the DML statement.
Use client.PartitionedUpdate to run a DML statement in this way. Not all DML
Expand Down
124 changes: 124 additions & 0 deletions spanner/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2943,6 +2943,130 @@ func TestIntegration_DML(t *testing.T) {
}
}

func TestIntegration_DMLReturning(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements])
defer cleanup()

query := func(ctx context.Context, tx *ReadWriteTransaction, query string, expectedRows int) ([]*Row, int64) {
iter := tx.Query(ctx, Statement{
SQL: query,
})
defer iter.Stop()
var rows []*Row
for len(rows) < expectedRows {
row, err := iter.Next()
if err != nil {
t.Fatalf("wanted row, got error %v", err)
}
rows = append(rows, row)
}
if _, err := iter.Next(); err != iterator.Done {
if err == nil {
t.Fatalf("got more rows than expected, wanted %d rows", expectedRows)
} else {
t.Fatalf("wanted iterator.Done, got error %v", err)
}
}
return rows, iter.RowCount
}

queries := []string{
`INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) THEN RETURN Nickname, Balance`,
`INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) THEN RETURN Balance + 1 AS BalancePlusOne`,
`UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE THEN RETURN AccountId, Balance`,
`DELETE FROM Accounts WHERE AccountId = 1 THEN RETURN *`,
}

if testDialect == adminpb.DatabaseDialect_POSTGRESQL {
queries = []string{
`INSERT INTO Accounts (AccountId, Balance) VALUES (1, 10) RETURNING Nickname, Balance`,
`INSERT INTO Accounts (AccountId, Balance) VALUES (2, 20) RETURNING Balance + 1 AS BalancePlusOne`,
`UPDATE Accounts SET Balance = Balance + 100 WHERE TRUE RETURNING AccountId, Balance`,
`DELETE FROM Accounts WHERE AccountId = 1 RETURNING *`,
}
}
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
// Insert first account, returning Nickname and Balance.
rows, _ := query(ctx, tx, queries[0], 1)
var nick *string
var bal int64
if err := rows[0].Columns(&nick, &bal); err != nil {
return err
}
if want := (*string)(nil); want != nick {
t.Errorf("got %v, want %v", nick, want)
}
if want := int64(10); want != bal {
t.Errorf("got %d, want %d", bal, want)
}

// Insert second account, returning Balance+1
rows, _ = query(ctx, tx, queries[1], 1)
var balancePlusOne int64
if err := rows[0].Columns(&balancePlusOne); err != nil {
return err
}
if want := int64(21); want != balancePlusOne {
t.Errorf("got %q, want %q", balancePlusOne, want)
}

// Update both accounts, returning AccountId and Balance for each.
rows, _ = query(ctx, tx, queries[2], 2)
balances := make(map[int64]int64)
for _, row := range rows {
var acctID int64
var bal int64
if err := row.Columns(&acctID, &bal); err != nil {
return err
}
balances[acctID] = bal
}
if want := map[int64]int64{1: 110, 2: 120}; !testEqual(want, balances) {
t.Errorf("got %#v, want %#v", balances, want)
}

return nil
})
if err != nil {
t.Fatal(err)
}

// Delete an account, returning all columns
_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
rows, modifiedRowCount := query(ctx, tx, queries[3], 1)

if want := int64(1); modifiedRowCount != want {
t.Fatalf("got %v, want %v", modifiedRowCount, want)
}

var acctID int64
var nick *string
var bal int64
if err := rows[0].Columns(&acctID, &nick, &bal); err != nil {
t.Fatal(err)
}

if want := int64(1); want != acctID {
t.Fatalf("got %v, want %v", acctID, want)
}
if want := (*string)(nil); want != nick {
t.Fatalf("got %v, want %v", nick, want)
}
if want := int64(110); want != bal {
t.Fatalf("got %v, want %v", bal, want)
}

return nil
})
if err != nil {
t.Fatal(err)
}
}

func TestIntegration_StructParametersBind(t *testing.T) {
t.Parallel()
skipUnsupportedPGTest(t)
Expand Down
1 change: 1 addition & 0 deletions spanner/spannertest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ by ascending esotericism:
- SELECT HAVING
- more literal types
- DEFAULT
- THEN RETURN
- expressions that return null for generated columns
- generated columns referencing other generated columns
- checking dependencies on a generated column before deleting a column
Expand Down
39 changes: 34 additions & 5 deletions spanner/spansql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1466,16 +1466,19 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
/*
DELETE [FROM] target_name [[AS] alias]
WHERE condition
THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...]
UPDATE target_name [[AS] alias]
SET update_item [, ...]
WHERE condition
THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...]
update_item: path_expression = expression | path_expression = DEFAULT
INSERT [INTO] target_name
(column_name_1 [, ..., column_name_n] )
input
THEN RETURN { [ expression. ]* | expression [ [ AS ] alias ] } [, ...]
input:
VALUES (row_1_column_1_expr [, ..., row_1_column_n_expr ] )
Expand All @@ -1499,9 +1502,15 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
if err != nil {
return nil, err
}
list, aliases, err := p.parseThenReturn()
if err != nil {
return nil, err
}
return &Delete{
Table: tname,
Where: where,
Table: tname,
Where: where,
Return: list,
ReturnAliases: aliases,
}, nil
}

Expand Down Expand Up @@ -1536,6 +1545,12 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
return nil, err
}
u.Where = where
list, aliases, err := p.parseThenReturn()
if err != nil {
return nil, err
}
u.Return = list
u.ReturnAliases = aliases
return u, nil
}

Expand Down Expand Up @@ -1572,16 +1587,30 @@ func (p *parser) parseDMLStmt() (DMLStmt, *parseError) {
}
}

list, aliases, err := p.parseThenReturn()
if err != nil {
return nil, err
}

return &Insert{
Table: tname,
Columns: columns,
Input: input,
Table: tname,
Columns: columns,
Input: input,
Return: list,
ReturnAliases: aliases,
}, nil
}

return nil, p.errorf("unknown DML statement")
}

func (p *parser) parseThenReturn() ([]Expr, []ID, *parseError) {
if p.eat("THEN", "RETURN") {
return p.parseSelectList()
}
return nil, nil, nil
}

func (p *parser) parseUpdateItem() (UpdateItem, *parseError) {
col, err := p.parseTableOrIndexOrColumnName()
if err != nil {
Expand Down
31 changes: 31 additions & 0 deletions spanner/spansql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ func TestParseDMLStmt(t *testing.T) {
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
},
},
{
"INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards') THEN RETURN FirstName",
&Insert{
Table: "Singers",
Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")},
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
Return: []Expr{ID("FirstName")},
},
},
{
"INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')",
&Insert{
Expand All @@ -338,6 +347,28 @@ func TestParseDMLStmt(t *testing.T) {
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
},
},
{
"INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, 'Marc', 'Richards')" +
" THEN RETURN SingerId, FirstName || ' ' || LastName AS Name",
&Insert{
Table: "Singers",
Columns: []ID{ID("SingerId"), ID("FirstName"), ID("LastName")},
Input: Values{{IntegerLiteral(1), StringLiteral("Marc"), StringLiteral("Richards")}},
Return: []Expr{
ID("SingerId"),
ArithOp{
Op: Concat,
LHS: ArithOp{
Op: Concat,
LHS: ID("FirstName"),
RHS: StringLiteral(" "),
},
RHS: ID("LastName"),
},
},
ReturnAliases: []ID{ID(""), ID("Name")},
},
},
{
"INSERT Singers (SingerId, FirstName, LastName) SELECT * FROM UNNEST ([1, 2, 3]) AS data",
&Insert{
Expand Down
46 changes: 32 additions & 14 deletions spanner/spansql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ func (do DatabaseOptions) SQL() string {
}

func (d *Delete) SQL() string {
return "DELETE FROM " + d.Table.SQL() + " WHERE " + d.Where.SQL()
return "DELETE FROM " + d.Table.SQL() +
" WHERE " + d.Where.SQL() +
thenReturnSQL(d.Return, d.ReturnAliases)
}

func (u *Update) SQL() string {
Expand All @@ -258,6 +260,7 @@ func (u *Update) SQL() string {
}
}
str += " WHERE " + u.Where.SQL()
str += thenReturnSQL(u.Return, u.ReturnAliases)
return str
}

Expand All @@ -271,6 +274,7 @@ func (i *Insert) SQL() string {
}
str += ") "
str += i.Input.SQL()
str += thenReturnSQL(i.Return, i.ReturnAliases)
return str
}

Expand All @@ -293,6 +297,16 @@ func (v Values) SQL() string {
return str
}

func thenReturnSQL(list []Expr, aliases []ID) string {
if len(list) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString(" THEN RETURN ")
addSelectList(&sb, list, aliases)
return sb.String()
}

func (cd ColumnDef) SQL() string {
str := cd.Name.SQL() + " " + cd.Type.SQL()
if cd.NotNull {
Expand Down Expand Up @@ -411,19 +425,7 @@ func (sel Select) addSQL(sb *strings.Builder) {
if sel.Distinct {
sb.WriteString("DISTINCT ")
}
for i, e := range sel.List {
if i > 0 {
sb.WriteString(", ")
}
e.addSQL(sb)
if len(sel.ListAliases) > 0 {
alias := sel.ListAliases[i]
if alias != "" {
sb.WriteString(" AS ")
sb.WriteString(alias.SQL())
}
}
}
addSelectList(sb, sel.List, sel.ListAliases)
if len(sel.From) > 0 {
sb.WriteString(" FROM ")
for i, f := range sel.From {
Expand All @@ -443,6 +445,22 @@ func (sel Select) addSQL(sb *strings.Builder) {
}
}

func addSelectList(sb *strings.Builder, list []Expr, aliases []ID) {
for i, e := range list {
if i > 0 {
sb.WriteString(", ")
}
e.addSQL(sb)
if len(aliases) > 0 {
alias := aliases[i]
if alias != "" {
sb.WriteString(" AS ")
sb.WriteString(alias.SQL())
}
}
}
}

func (sft SelectFromTable) SQL() string {
str := sft.Table.SQL()
if len(sft.Hints) > 0 {
Expand Down
Loading

0 comments on commit c07243c

Please sign in to comment.