From 4067f4eabbba9e5e3265754339709fb273b73e4c Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Tue, 23 Apr 2024 13:48:20 -0700 Subject: [PATCH] feat(datastore): Adding BeginLater transaction option (#8972) * feat(datastore): new_transaction consistency --- datastore/datastore.go | 6 +- datastore/integration_test.go | 250 +++++++++++++++++++++ datastore/mock_test.go | 30 ++- datastore/query.go | 44 ++-- datastore/query_test.go | 18 +- datastore/transaction.go | 208 +++++++++++++----- datastore/transaction_test.go | 395 ++++++++++++++++++++++++++++++++++ 7 files changed, 876 insertions(+), 75 deletions(-) diff --git a/datastore/datastore.go b/datastore/datastore.go index 52680ea275f5..61020bf175cd 100644 --- a/datastore/datastore.go +++ b/datastore/datastore.go @@ -398,7 +398,8 @@ func (c *Client) Get(ctx context.Context, key *Key, dst interface{}) (err error) } } - // TODO: Use transaction ID returned by get + // Since opts does not contain Transaction option, 'get' call below will return nil + // as transaction id which can be ignored _, err = c.get(ctx, []*Key{key}, []interface{}{dst}, opts) if me, ok := err.(MultiError); ok { return me[0] @@ -432,7 +433,8 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst interface{}) (er } } - // TODO: Use transaction ID returned by get + // Since opts does not contain Transaction option, 'get' call below will return nil + // as transaction id which can be ignored _, err = c.get(ctx, keys, dst, opts) return err } diff --git a/datastore/integration_test.go b/datastore/integration_test.go index ab0436bdff5e..3c94a63a92a1 100644 --- a/datastore/integration_test.go +++ b/datastore/integration_test.go @@ -753,6 +753,256 @@ func TestIntegration_Filters(t *testing.T) { }) } +func populateData(t *testing.T, client *Client, childrenCount int, time int64, testKey string) ([]*Key, *Key, func()) { + ctx := context.Background() + parent := NameKey("SQParent", keyPrefix+testKey+suffix, nil) + + children := []*SQChild{} + + for i := 0; i < childrenCount; i++ { + children = append(children, &SQChild{I: i, T: time, U: time, V: 1.5, W: "str"}) + } + keys := make([]*Key, childrenCount) + for i := range keys { + keys[i] = NameKey("SQChild", "sqChild"+fmt.Sprint(i), parent) + } + keys, err := client.PutMulti(ctx, keys, children) + if err != nil { + t.Fatalf("client.PutMulti: %v", err) + } + + cleanup := func() { + err := client.DeleteMulti(ctx, keys) + if err != nil { + t.Errorf("client.DeleteMulti: %v", err) + } + } + return keys, parent, cleanup +} + +type RunTransactionResult struct { + runTime float64 + err error +} + +func TestIntegration_BeginLaterPerf(t *testing.T) { + if testing.Short() { + t.Skip("Integration tests skipped in short mode") + } + runOptions := []bool{true, false} // whether BeginLater transaction option is used + var avgRunTimes [2]float64 // In seconds + numRepetitions := 10 + numKeys := 10 + + res := make(chan RunTransactionResult) + for i, runOption := range runOptions { + sumRunTime := float64(0) + + // Create client + ctx := context.Background() + client := newTestClient(ctx, t) + defer client.Close() + + // Populate data + now := timeNow.Truncate(time.Millisecond).Unix() + keys, _, cleanupData := populateData(t, client, numKeys, now, "BeginLaterPerf"+fmt.Sprint(runOption)+fmt.Sprint(now)) + defer cleanupData() + + for rep := 0; rep < numRepetitions; rep++ { + go runTransaction(ctx, client, keys, res, runOption, t) + } + for rep := 0; rep < numRepetitions; rep++ { + runTransactionResult := <-res + if runTransactionResult.err != nil { + t.Fatal(runTransactionResult.err) + } + sumRunTime += runTransactionResult.runTime + } + + avgRunTimes[i] = sumRunTime / float64(numRepetitions) + } + improvement := ((avgRunTimes[1] - avgRunTimes[0]) / avgRunTimes[1]) * 100 + if improvement < 0 { + t.Logf("Run times:: with BeginLater: %.3fs, without BeginLater: %.3fs. improvement: %.2f%%", avgRunTimes[0], avgRunTimes[1], improvement) + t.Fatal("No perf improvement because of new transaction consistency type.") + } +} + +func runTransaction(ctx context.Context, client *Client, keys []*Key, res chan RunTransactionResult, beginLater bool, t *testing.T) { + + numKeys := len(keys) + txOpts := []TransactionOption{} + if beginLater { + txOpts = append(txOpts, BeginLater) + } + + start := time.Now() + // Create transaction + tx, err := client.NewTransaction(ctx, txOpts...) + if err != nil { + runTransactionResult := RunTransactionResult{ + err: fmt.Errorf("Failed to create transaction: %v", err), + } + res <- runTransactionResult + return + } + + // Perform operations in transaction + dst := make([]*SQChild, numKeys) + if err := tx.GetMulti(keys, dst); err != nil { + runTransactionResult := RunTransactionResult{ + err: fmt.Errorf("GetMulti got: %v, want: nil", err), + } + res <- runTransactionResult + return + } + if _, err := tx.PutMulti(keys, dst); err != nil { + runTransactionResult := RunTransactionResult{ + err: fmt.Errorf("PutMulti got: %v, want: nil", err), + } + res <- runTransactionResult + return + } + + // Commit the transaction + if _, err := tx.Commit(); err != nil { + runTransactionResult := RunTransactionResult{ + err: fmt.Errorf("Commit got: %v, want: nil", err), + } + res <- runTransactionResult + return + } + + runTransactionResult := RunTransactionResult{ + runTime: time.Since(start).Seconds(), + } + res <- runTransactionResult +} + +func TestIntegration_BeginLater(t *testing.T) { + if testing.Short() { + t.Skip("Integration tests skipped in short mode") + } + ctx := context.Background() + client := newTestClient(ctx, t) + defer client.Close() + + wantAggResult := AggregationResult(map[string]interface{}{ + "count": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 3}}, + "sum": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 3}}, + "avg": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: 1}}, + }) + + mockErr := errors.New("Mock error") + testcases := []struct { + desc string + options []TransactionOption + hasReadOnlyOption bool + failTransaction bool + }{ + { + desc: "Failed transaction with BeginLater, MaxAttempts(2), ReadOnly options", + options: []TransactionOption{BeginLater, MaxAttempts(2), ReadOnly}, + hasReadOnlyOption: true, + failTransaction: true, + }, + { + desc: "BeginLater, MaxAttempts(2), ReadOnly", + options: []TransactionOption{BeginLater, MaxAttempts(2), ReadOnly}, + hasReadOnlyOption: true, + failTransaction: false, + }, + { + desc: "BeginLater, MaxAttempts(2)", + options: []TransactionOption{BeginLater, MaxAttempts(2)}, + hasReadOnlyOption: false, + }, + { + desc: "BeginLater, ReadOnly", + options: []TransactionOption{BeginLater, ReadOnly}, + hasReadOnlyOption: true, + }, + } + + for _, testcase := range testcases { + // Populate data + now := timeNow.Truncate(time.Millisecond).Unix() + keys, parent, cleanupData := populateData(t, client, 3, now, "BeginLater") + + testutil.Retry(t, 5, 10*time.Second, func(r *testutil.R) { + _, err := client.RunInTransaction(ctx, func(tx *Transaction) error { + query := NewQuery("SQChild").Ancestor(parent).FilterField("T", "=", now).Transaction(tx) + dst := []*SQChild{} + if _, err := client.GetAll(ctx, query, &dst); err != nil { + return err + } + + aggQuery := query.NewAggregationQuery(). + WithCount("count"). + WithSum("I", "sum"). + WithAvg("I", "avg") + gotAggResult, err := client.RunAggregationQuery(ctx, aggQuery) + if err != nil { + return err + } + if !reflect.DeepEqual(gotAggResult, wantAggResult) { + return fmt.Errorf("Mismatch in aggregation result got: %+v, want: %+v", gotAggResult, wantAggResult) + } + + if !testcase.hasReadOnlyOption { + v := &SQChild{I: 22, T: now, U: now, V: 1.5, W: "str"} + if _, err := tx.Put(keys[0], v); err != nil { + return err + } + + if err := tx.Delete(keys[1]); err != nil { + return err + } + } + if testcase.failTransaction { + // Deliberately, fail the transaction to rollback it + return mockErr + } + return nil + }, testcase.options...) + + if !testcase.failTransaction { + if err != nil { + r.Errorf("%v got: %v, want: nil", testcase.desc, err) + } + if !testcase.hasReadOnlyOption { + // Transactions are atomic. Check if Put and Delete succeeded ensuring they were run as transaction + verifyBeginLater(r, testcase.desc+" Committed Put", client, parent, now, 22, 1) + verifyBeginLater(r, testcase.desc+" Committed Delete", client, parent, now, 1, 0) + } + } else { + if err == nil { + r.Errorf("%v got: nil, want: %v", testcase.desc, mockErr) + } + if !testcase.hasReadOnlyOption { + // Transactions are atomic. Check if Put and Delete rollbacked ensuring they were run as transaction + verifyBeginLater(r, testcase.desc+" Rollbacked Put", client, parent, now, 22, 0) + verifyBeginLater(r, testcase.desc+" Rollbacked Delete", client, parent, now, 1, 1) + } + } + }) + cleanupData() + } +} + +func verifyBeginLater(r *testutil.R, errPrefix string, client *Client, parent *Key, tvalue int64, ivalue, wantDstLen int) { + ctx := context.Background() + query := NewQuery("SQChild").Ancestor(parent).FilterField("T", "=", tvalue).FilterField("I", "=", ivalue) + dst := []*SQChild{} + _, err := client.GetAll(ctx, query, &dst) + if err != nil { + r.Errorf("%v GetAll got: %v, want: nil", errPrefix, err) + } + if len(dst) != wantDstLen { + r.Errorf("%v len(dst) got: %v, want: %v", errPrefix, len(dst), wantDstLen) + } +} + func TestIntegration_AggregationQueriesInTransaction(t *testing.T) { ctx := context.Background() client := newTestClient(ctx, t) diff --git a/datastore/mock_test.go b/datastore/mock_test.go index 5a406b554226..888fd002c287 100644 --- a/datastore/mock_test.go +++ b/datastore/mock_test.go @@ -49,6 +49,10 @@ type reqItem struct { adjust func(gotReq proto.Message) } +const ( + mockProjectID = "projectID" +) + func newMock(t *testing.T) (_ *Client, _ *mockServer, _ func()) { srv, cleanup, err := newMockServer() if err != nil { @@ -59,7 +63,7 @@ func newMock(t *testing.T) (_ *Client, _ *mockServer, _ func()) { if err != nil { t.Fatal(err) } - client, err := NewClient(context.Background(), "projectID", option.WithGRPCConn(conn)) + client, err := NewClient(context.Background(), mockProjectID, option.WithGRPCConn(conn)) if err != nil { t.Fatal(err) } @@ -153,3 +157,27 @@ func (s *mockServer) Commit(_ context.Context, in *pb.CommitRequest) (*pb.Commit } return res.(*pb.CommitResponse), nil } + +func (s *mockServer) BeginTransaction(ctx context.Context, in *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) { + res, err := s.popRPC(in) + if err != nil { + return nil, err + } + return res.(*pb.BeginTransactionResponse), nil +} + +func (s *mockServer) RunQuery(ctx context.Context, in *pb.RunQueryRequest) (*pb.RunQueryResponse, error) { + res, err := s.popRPC(in) + if err != nil { + return nil, err + } + return res.(*pb.RunQueryResponse), nil +} + +func (s *mockServer) RunAggregationQuery(ctx context.Context, in *pb.RunAggregationQueryRequest) (*pb.RunAggregationQueryResponse, error) { + res, err := s.popRPC(in) + if err != nil { + return nil, err + } + return res.(*pb.RunAggregationQueryResponse), nil +} diff --git a/datastore/query.go b/datastore/query.go index ad5fb2994480..a949362bef2f 100644 --- a/datastore/query.go +++ b/datastore/query.go @@ -505,7 +505,7 @@ func (q *Query) toRunQueryRequest(req *pb.RunQueryRequest) error { return err } - req.ReadOptions, err = parseReadOptions(q.eventual, q.trans) + req.ReadOptions, err = parseQueryReadOptions(q.eventual, q.trans) if err != nil { return err } @@ -798,7 +798,12 @@ func (c *Client) RunAggregationQuery(ctx context.Context, aq *AggregationQuery) } // Parse the read options. - req.ReadOptions, err = parseReadOptions(aq.query.eventual, aq.query.trans) + txn := aq.query.trans + if txn != nil { + defer txn.acquireLock()() + } + + req.ReadOptions, err = parseQueryReadOptions(aq.query.eventual, txn) if err != nil { return nil, err } @@ -808,6 +813,10 @@ func (c *Client) RunAggregationQuery(ctx context.Context, aq *AggregationQuery) return nil, err } + if txn != nil && txn.state == transactionStateNotStarted { + txn.startProgress(res.Transaction) + } + ar = make(AggregationResult) // TODO(developer): change batch parsing logic if other aggregations are supported. @@ -824,26 +833,24 @@ func validateReadOptions(eventual bool, t *Transaction) error { if t == nil { return nil } - if t.id == nil { - return errExpiredTransaction - } if eventual { return errors.New("datastore: cannot use EventualConsistency query in a transaction") } + if t.state == transactionStateExpired { + return errExpiredTransaction + } return nil } -// parseReadOptions translates Query read options into protobuf format. -func parseReadOptions(eventual bool, t *Transaction) (*pb.ReadOptions, error) { +// parseQueryReadOptions translates Query read options into protobuf format. +func parseQueryReadOptions(eventual bool, t *Transaction) (*pb.ReadOptions, error) { err := validateReadOptions(eventual, t) if err != nil { return nil, err } if t != nil { - return &pb.ReadOptions{ - ConsistencyType: &pb.ReadOptions_Transaction{Transaction: t.id}, - }, nil + return t.parseReadOptions() } if eventual { @@ -884,11 +891,9 @@ type Iterator struct { entityCursor []byte // trans records the transaction in which the query was run - // Currently, this value is set but unused trans *Transaction // eventual records whether the query was eventual - // Currently, this value is set but unused eventual bool } @@ -956,12 +961,27 @@ func (t *Iterator) nextBatch() error { q.Limit = nil } + txn := t.trans + if txn != nil { + defer txn.acquireLock()() + } + + var err error + t.req.ReadOptions, err = parseQueryReadOptions(t.eventual, txn) + if err != nil { + return err + } + // Run the query. resp, err := t.client.client.RunQuery(t.ctx, t.req) if err != nil { return err } + if txn != nil && txn.state == transactionStateNotStarted { + txn.startProgress(resp.Transaction) + } + // Adjust any offset from skipped results. skip := resp.Batch.SkippedResults if skip < 0 { diff --git a/datastore/query_test.go b/datastore/query_test.go index 879da29af803..97e6fbe57ccf 100644 --- a/datastore/query_test.go +++ b/datastore/query_test.go @@ -757,9 +757,6 @@ func TestReadOptions(t *testing.T) { if err := test.q.toRunQueryRequest(req); err != nil { t.Fatalf("%+v: got %v, want no error", test.q, err) } - if got := req.ReadOptions; !proto.Equal(got, test.want) { - t.Errorf("%+v:\ngot %+v\nwant %+v", test.q, got, test.want) - } } // Test errors. for _, q := range []*Query{ @@ -889,6 +886,8 @@ func TestAggregationQueryIsNil(t *testing.T) { } func TestValidateReadOptions(t *testing.T) { + eventualInTxnErr := errors.New("datastore: cannot use EventualConsistency query in a transaction") + for _, test := range []struct { desc string eventual bool @@ -899,24 +898,27 @@ func TestValidateReadOptions(t *testing.T) { desc: "EventualConsistency query in a transaction", eventual: true, trans: &Transaction{ - id: []byte("test id"), + id: []byte("test id"), + state: transactionStateInProgress, }, - wantErr: errors.New("datastore: cannot use EventualConsistency query in a transaction"), + wantErr: eventualInTxnErr, }, { desc: "Expired transaction in non-eventual query", trans: &Transaction{ - id: nil, + id: nil, + state: transactionStateExpired, }, wantErr: errExpiredTransaction, }, { desc: "Expired transaction in eventual query", trans: &Transaction{ - id: nil, + id: nil, + state: transactionStateExpired, }, eventual: true, - wantErr: errExpiredTransaction, + wantErr: eventualInTxnErr, }, { desc: "No transaction in non-eventual query", diff --git a/datastore/transaction.go b/datastore/transaction.go index 647e50b27942..b35c6a31ce65 100644 --- a/datastore/transaction.go +++ b/datastore/transaction.go @@ -17,6 +17,7 @@ package datastore import ( "context" "errors" + "sync" "time" "cloud.google.com/go/internal/trace" @@ -42,8 +43,6 @@ type transactionSettings struct { // uses the piggybacked txn id from first read rpc call. // If there are no read operations on transaction, BeginTransaction RPC call is made // before rollback or commit - // Currently, this setting is set but unused - // TODO: b/291258189 - Use this setting beginLater bool } @@ -109,8 +108,6 @@ func (readOnly) apply(s *transactionSettings) { } // BeginLater is a TransactionOption that can be used to improve transaction performance -// Currently, it is a no-op -// TODO: b/291258189 - Add implementation var BeginLater TransactionOption type beginLater struct{} @@ -122,7 +119,7 @@ func (beginLater) apply(s *transactionSettings) { type transactionState int const ( - transactionStateNotStarted transactionState = iota // Currently unused + transactionStateNotStarted transactionState = iota transactionStateInProgress transactionStateExpired ) @@ -144,6 +141,7 @@ type Transaction struct { pending map[int]*PendingKey // Map from mutation index to incomplete keys pending transaction completion. settings *transactionSettings state transactionState + stateLock sync.Mutex } // NewTransaction starts a new transaction. @@ -159,47 +157,115 @@ func (c *Client) NewTransaction(ctx context.Context, opts ...TransactionOption) return c.newTransaction(ctx, newTransactionSettings(opts)) } -func (c *Client) newTransaction(ctx context.Context, s *transactionSettings) (_ *Transaction, err error) { - req := &pb.BeginTransactionRequest{ - ProjectId: c.dataset, - DatabaseId: c.databaseID, +func (t *Transaction) parseTransactionOptions() (*pb.TransactionOptions, string) { + if t.settings == nil { + return nil, "" } - if s.readOnly { - ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Transaction.ReadOnlyTransaction") - defer func() { trace.EndSpan(ctx, err) }() + if t.settings.readOnly { ro := &pb.TransactionOptions_ReadOnly{} - if !s.readTime.AsTime().IsZero() { - ro.ReadTime = s.readTime + if !t.settings.readTime.AsTime().IsZero() { + ro.ReadTime = t.settings.readTime } - req.TransactionOptions = &pb.TransactionOptions{ + return &pb.TransactionOptions{ Mode: &pb.TransactionOptions_ReadOnly_{ReadOnly: ro}, - } - - } else if s.prevID != nil { - ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Transaction.ReadWriteTransaction") - defer func() { trace.EndSpan(ctx, err) }() + }, "cloud.google.com/go/datastore.Transaction.ReadOnlyTransaction" + } - req.TransactionOptions = &pb.TransactionOptions{ + if t.settings.prevID != nil { + return &pb.TransactionOptions{ Mode: &pb.TransactionOptions_ReadWrite_{ReadWrite: &pb.TransactionOptions_ReadWrite{ - PreviousTransaction: s.prevID, + PreviousTransaction: t.settings.prevID, }}, - } + }, "cloud.google.com/go/datastore.Transaction.ReadWriteTransaction" } - resp, err := c.client.BeginTransaction(ctx, req) + return nil, "" +} + +// beginTransaction makes BeginTransaction rpc +func (t *Transaction) beginTransaction() (txnID []byte, err error) { + + req := &pb.BeginTransactionRequest{ + ProjectId: t.client.dataset, + DatabaseId: t.client.databaseID, + } + + txOptionsPb, spanName := t.parseTransactionOptions() + if txOptionsPb != nil { + t.ctx = trace.StartSpan(t.ctx, spanName) + defer func() { trace.EndSpan(t.ctx, err) }() + req.TransactionOptions = txOptionsPb + } + + resp, err := t.client.client.BeginTransaction(t.ctx, req) if err != nil { return nil, err } - return &Transaction{ - id: resp.Transaction, + return resp.Transaction, nil +} + +// beginLaterTransaction makes BeginTransaction rpc if transaction has not yet started +func (t *Transaction) beginLaterTransaction() (err error) { + if t.state != transactionStateNotStarted { + return nil + } + + // Obtain state lock since the state needs to be updated + // after transaction has started + t.stateLock.Lock() + defer t.stateLock.Unlock() + if t.state != transactionStateNotStarted { + return nil + } + + txnID, err := t.beginTransaction() + if err != nil { + return err + } + + t.startProgress(txnID) + return nil +} + +// Acquire state lock if transaction has not started +func (t *Transaction) acquireLock() func() { + if t.state == transactionStateNotStarted { + t.stateLock.Lock() + // Check whether state changed while waiting to acquire lock + if t.state == transactionStateNotStarted { + return func() { t.stateLock.Unlock() } + } + t.stateLock.Unlock() + } + return func() {} +} + +func (t *Transaction) startProgress(id []byte) { + t.id = id + t.state = transactionStateInProgress +} + +func (c *Client) newTransaction(ctx context.Context, s *transactionSettings) (_ *Transaction, err error) { + t := &Transaction{ + id: nil, ctx: ctx, client: c, mutations: nil, pending: make(map[int]*PendingKey), - state: transactionStateInProgress, settings: s, - }, nil + } + + t.state = transactionStateNotStarted + if !s.beginLater { + txnID, err := t.beginTransaction() + if err != nil { + return nil, err + } + t.startProgress(txnID) + } + + return t, nil } // RunInTransaction runs f in a transaction. f is invoked with a Transaction @@ -258,6 +324,12 @@ func (t *Transaction) Commit() (c *Commit, err error) { if t.state == transactionStateExpired { return nil, errExpiredTransaction } + + err = t.beginLaterTransaction() + if err != nil { + return nil, err + } + req := &pb.CommitRequest{ ProjectId: t.client.dataset, DatabaseId: t.client.databaseID, @@ -269,11 +341,12 @@ func (t *Transaction) Commit() (c *Commit, err error) { if status.Code(err) == codes.Aborted { return nil, ErrConcurrentTransaction } - t.state = transactionStateExpired // mark the transaction as expired if err != nil { return nil, err } + t.state = transactionStateExpired + c = &Commit{} // Copy any newly minted keys into the returned keys. for i, p := range t.pending { @@ -299,12 +372,63 @@ func (t *Transaction) Rollback() (err error) { if t.state == transactionStateExpired { return errExpiredTransaction } - t.state = transactionStateExpired + + err = t.beginLaterTransaction() + if err != nil { + return err + } + _, err = t.client.client.Rollback(t.ctx, &pb.RollbackRequest{ ProjectId: t.client.dataset, DatabaseId: t.client.databaseID, Transaction: t.id, }) + if err != nil { + return err + } + + t.state = transactionStateExpired + return nil +} + +func (t *Transaction) parseReadOptions() (*pb.ReadOptions, error) { + var opts *pb.ReadOptions + switch t.state { + case transactionStateExpired: + return nil, errExpiredTransaction + case transactionStateInProgress: + opts = &pb.ReadOptions{ + // Use existing transaction id for this request + ConsistencyType: &pb.ReadOptions_Transaction{Transaction: t.id}, + } + case transactionStateNotStarted: + tOptionsPb, _ := t.parseTransactionOptions() + opts = &pb.ReadOptions{ + // Begin a new transaction for this request + ConsistencyType: &pb.ReadOptions_NewTransaction{NewTransaction: tOptionsPb}, + } + } + return opts, nil +} + +func (t *Transaction) get(spanName string, keys []*Key, dst interface{}) (err error) { + t.ctx = trace.StartSpan(t.ctx, spanName) + defer func() { trace.EndSpan(t.ctx, err) }() + + if t != nil { + defer t.acquireLock()() + } + + opts, err := t.parseReadOptions() + if err != nil { + return err + } + + txnID, err := t.client.get(t.ctx, keys, dst, opts) + + if txnID != nil && err == nil { + t.startProgress(txnID) + } return err } @@ -314,15 +438,7 @@ func (t *Transaction) Rollback() (err error) { // level, another transaction cannot concurrently modify the data that is read // or modified by this transaction. func (t *Transaction) Get(key *Key, dst interface{}) (err error) { - t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Transaction.Get") - defer func() { trace.EndSpan(t.ctx, err) }() - - opts := &pb.ReadOptions{ - ConsistencyType: &pb.ReadOptions_Transaction{Transaction: t.id}, - } - - // TODO: Use transaction ID returned by get - _, err = t.client.get(t.ctx, []*Key{key}, []interface{}{dst}, opts) + err = t.get("cloud.google.com/go/datastore.Transaction.Get", []*Key{key}, []interface{}{dst}) if me, ok := err.(MultiError); ok { return me[0] } @@ -331,19 +447,7 @@ func (t *Transaction) Get(key *Key, dst interface{}) (err error) { // GetMulti is a batch version of Get. func (t *Transaction) GetMulti(keys []*Key, dst interface{}) (err error) { - t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Transaction.GetMulti") - defer func() { trace.EndSpan(t.ctx, err) }() - - if t.state == transactionStateExpired { - return errExpiredTransaction - } - opts := &pb.ReadOptions{ - ConsistencyType: &pb.ReadOptions_Transaction{Transaction: t.id}, - } - - // TODO: Use transaction ID returned by get - _, err = t.client.get(t.ctx, keys, dst, opts) - return err + return t.get("cloud.google.com/go/datastore.Transaction.GetMulti", keys, dst) } // Put is the transaction-specific version of the package function Put. diff --git a/datastore/transaction_test.go b/datastore/transaction_test.go index a6be74c5b149..c058be7f894f 100644 --- a/datastore/transaction_test.go +++ b/datastore/transaction_test.go @@ -88,3 +88,398 @@ func TestNewTransaction(t *testing.T) { } } } + +func TestBeginLaterTransactionOption(t *testing.T) { + type ent struct { + A int + } + type addRPCInput struct { + wantReq proto.Message + resp interface{} + } + + mockKind := "mockKind" + mockTxnID := []byte("tid") + mockKey := NameKey(mockKind, "testName", nil) + mockEntity := &pb.Entity{ + Key: keyToProto(mockKey), + Properties: map[string]*pb.Value{ + "A": {ValueType: &pb.Value_IntegerValue{IntegerValue: 0}}, + }, + } + mockEntityResults := []*pb.EntityResult{ + { + Entity: mockEntity, + Version: 1, + }, + } + + // Requests and responses to be used in tests + txnReadOptions := &pb.ReadOptions{ + ConsistencyType: &pb.ReadOptions_Transaction{ + Transaction: mockTxnID, + }, + } + newTxnReadOptions := &pb.ReadOptions{ + ConsistencyType: &pb.ReadOptions_NewTransaction{}, + } + + lookupReqWithTxn := &pb.LookupRequest{ + ProjectId: mockProjectID, + DatabaseId: "", + Keys: []*pb.Key{ + keyToProto(mockKey), + }, + ReadOptions: txnReadOptions, + } + lookupResWithTxn := &pb.LookupResponse{ + Found: mockEntityResults, + } + + lookupReqWithNewTxn := &pb.LookupRequest{ + ProjectId: mockProjectID, + DatabaseId: "", + Keys: []*pb.Key{ + keyToProto(mockKey), + }, + ReadOptions: newTxnReadOptions, + } + lookupResWithNewTxn := &pb.LookupResponse{ + Transaction: mockTxnID, + Found: mockEntityResults, + } + + runQueryReqWithTxn := &pb.RunQueryRequest{ + ProjectId: mockProjectID, + QueryType: &pb.RunQueryRequest_Query{Query: &pb.Query{ + Kind: []*pb.KindExpression{{Name: mockKind}}, + }}, + ReadOptions: txnReadOptions, + } + runQueryResWithTxn := &pb.RunQueryResponse{ + Batch: &pb.QueryResultBatch{ + MoreResults: pb.QueryResultBatch_NO_MORE_RESULTS, + EntityResultType: pb.EntityResult_FULL, + EntityResults: mockEntityResults, + }, + } + + runQueryReqWithNewTxn := &pb.RunQueryRequest{ + ProjectId: mockProjectID, + QueryType: &pb.RunQueryRequest_Query{Query: &pb.Query{ + Kind: []*pb.KindExpression{{Name: mockKind}}, + }}, + ReadOptions: newTxnReadOptions, + } + runQueryResWithNewTxn := &pb.RunQueryResponse{ + Transaction: mockTxnID, + Batch: &pb.QueryResultBatch{ + MoreResults: pb.QueryResultBatch_NO_MORE_RESULTS, + EntityResultType: pb.EntityResult_FULL, + EntityResults: mockEntityResults, + }, + } + + countAlias := "count" + runAggQueryReqWithTxn := &pb.RunAggregationQueryRequest{ + ProjectId: mockProjectID, + ReadOptions: txnReadOptions, + QueryType: &pb.RunAggregationQueryRequest_AggregationQuery{ + AggregationQuery: &pb.AggregationQuery{ + QueryType: &pb.AggregationQuery_NestedQuery{ + NestedQuery: &pb.Query{ + Kind: []*pb.KindExpression{{Name: mockKind}}, + }, + }, + Aggregations: []*pb.AggregationQuery_Aggregation{ + { + Operator: &pb.AggregationQuery_Aggregation_Count_{}, + Alias: countAlias, + }, + }, + }, + }, + } + runAggQueryResWithTxn := &pb.RunAggregationQueryResponse{ + Batch: &pb.AggregationResultBatch{ + AggregationResults: []*pb.AggregationResult{ + { + AggregateProperties: map[string]*pb.Value{ + countAlias: { + ValueType: &pb.Value_IntegerValue{IntegerValue: 1}, + }, + }, + }, + }, + }, + } + + runAggQueryReqWithNewTxn := &pb.RunAggregationQueryRequest{ + ProjectId: mockProjectID, + ReadOptions: newTxnReadOptions, + QueryType: &pb.RunAggregationQueryRequest_AggregationQuery{ + AggregationQuery: &pb.AggregationQuery{ + QueryType: &pb.AggregationQuery_NestedQuery{ + NestedQuery: &pb.Query{ + Kind: []*pb.KindExpression{{Name: mockKind}}, + }, + }, + Aggregations: []*pb.AggregationQuery_Aggregation{ + { + Operator: &pb.AggregationQuery_Aggregation_Count_{}, + Alias: countAlias, + }, + }, + }, + }, + } + runAggQueryResWithNewTxn := &pb.RunAggregationQueryResponse{ + Batch: &pb.AggregationResultBatch{ + AggregationResults: []*pb.AggregationResult{ + { + AggregateProperties: map[string]*pb.Value{ + countAlias: { + ValueType: &pb.Value_IntegerValue{IntegerValue: 1}, + }, + }, + }, + }, + }, + Transaction: mockTxnID, + } + + commitReq := &pb.CommitRequest{ + ProjectId: mockProjectID, + Mode: pb.CommitRequest_TRANSACTIONAL, + TransactionSelector: &pb.CommitRequest_Transaction{ + Transaction: mockTxnID, + }, + Mutations: []*pb.Mutation{ + { + Operation: &pb.Mutation_Upsert{ + Upsert: mockEntity, + }, + }, + }, + } + commitRes := &pb.CommitResponse{} + + beginTxnReq := &pb.BeginTransactionRequest{ + ProjectId: mockProjectID, + } + beginTxnRes := &pb.BeginTransactionResponse{ + Transaction: mockTxnID, + } + + testcases := []struct { + desc string + rpcInputs []addRPCInput + ops []string + settings *transactionSettings + }{ + { + desc: "[Get, Get, Put, Commit] No options. First Get does not pass new_transaction", + rpcInputs: []addRPCInput{ + { + wantReq: beginTxnReq, + resp: beginTxnRes, + }, + { + wantReq: lookupReqWithTxn, + resp: lookupResWithTxn, + }, + { + wantReq: lookupReqWithTxn, + resp: lookupResWithTxn, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"Get", "Get", "Put", "Commit"}, + settings: &transactionSettings{}, + }, + { + desc: "[Get, Get, Put, Commit] BeginLater. First Get passes new_transaction", + rpcInputs: []addRPCInput{ + { + wantReq: lookupReqWithNewTxn, + resp: lookupResWithNewTxn, + }, + { + wantReq: lookupReqWithTxn, + resp: lookupResWithTxn, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"Get", "Get", "Put", "Commit"}, + settings: &transactionSettings{beginLater: true}, + }, + { + desc: "[RunQuery, Get, Put, Commit] No options. RunQuery does not pass new_transaction", + rpcInputs: []addRPCInput{ + { + wantReq: beginTxnReq, + resp: beginTxnRes, + }, + { + wantReq: runQueryReqWithTxn, + resp: runQueryResWithTxn, + }, + { + wantReq: lookupReqWithTxn, + resp: lookupResWithTxn, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"RunQuery", "Get", "Put", "Commit"}, + settings: &transactionSettings{}, + }, + { + desc: "[RunQuery, Get, Put, Commit] BeginLater. RunQuery passes new_transaction", + rpcInputs: []addRPCInput{ + { + wantReq: runQueryReqWithNewTxn, + resp: runQueryResWithNewTxn, + }, + { + wantReq: lookupReqWithTxn, + resp: lookupResWithTxn, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"RunQuery", "Get", "Put", "Commit"}, + settings: &transactionSettings{beginLater: true}, + }, + { + desc: "[RunAggregationQuery, Get, Put, Commit] No options. RunAggregationQuery does not pass new_transaction", + rpcInputs: []addRPCInput{ + { + wantReq: beginTxnReq, + resp: beginTxnRes, + }, + { + wantReq: runAggQueryReqWithTxn, + resp: runAggQueryResWithTxn, + }, + { + wantReq: lookupReqWithTxn, + resp: lookupResWithTxn, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"RunAggregationQuery", "Get", "Put", "Commit"}, + settings: &transactionSettings{}, + }, + { + desc: "[RunAggregationQuery, Get, Put, Commit] BeginLater. RunAggregationQuery passes new_transaction", + rpcInputs: []addRPCInput{ + { + wantReq: runAggQueryReqWithNewTxn, + resp: runAggQueryResWithNewTxn, + }, + { + wantReq: lookupReqWithTxn, + resp: lookupResWithTxn, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"RunAggregationQuery", "Get", "Put", "Commit"}, + settings: &transactionSettings{beginLater: true}, + }, + { + desc: "[Put, Commit] No options. BeginTransaction request sent", + rpcInputs: []addRPCInput{ + { + wantReq: beginTxnReq, + resp: beginTxnRes, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"Put", "Commit"}, + settings: &transactionSettings{}, + }, + { + desc: "[Put, Commit] BeginLater. BeginTransaction request sent", + rpcInputs: []addRPCInput{ + { + wantReq: beginTxnReq, + resp: beginTxnRes, + }, + { + wantReq: commitReq, + resp: commitRes, + }, + }, + ops: []string{"Put", "Commit"}, + settings: &transactionSettings{beginLater: true}, + }, + } + + for _, testcase := range testcases { + ctx := context.Background() + client, srv, cleanup := newMock(t) + defer cleanup() + for _, rpcInput := range testcase.rpcInputs { + srv.addRPC(rpcInput.wantReq, rpcInput.resp) + } + + dst := &ent{} + + txn, err := client.newTransaction(ctx, testcase.settings) + if err != nil { + t.Fatalf("%q: %v", testcase.desc, err) + } + + for i, op := range testcase.ops { + switch op { + case "RunQuery": + query := NewQuery(mockKind).Transaction(txn) + got := []*ent{} + if _, err := client.GetAll(ctx, query, &got); err != nil { + t.Fatalf("%q RunQuery[%v] failed with error %v", testcase.desc, i, err) + } + case "RunAggregationQuery": + aggQuery := NewQuery(mockKind).Transaction(txn).NewAggregationQuery() + aggQuery.WithCount(countAlias) + + if _, err := client.RunAggregationQuery(ctx, aggQuery); err != nil { + t.Fatalf("%q RunAggregationQuery[%v] failed with error %v", testcase.desc, i, err) + } + case "Get": + if err := txn.Get(mockKey, dst); err != nil { + t.Fatalf("%q Get[%v] failed with error %v", testcase.desc, i, err) + } + case "Put": + _, err := txn.Put(mockKey, dst) + if err != nil { + t.Fatalf("%q Put[%v] failed with error %v", testcase.desc, i, err) + } + case "Commit": + _, err := txn.Commit() + if err != nil { + t.Fatalf("%q Commit[%v] failed with error %v", testcase.desc, i, err) + } + } + } + } +}