diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 2955b182a767..8d43e65bef7e 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -250,6 +250,7 @@ type SimulatedExecutionTime struct { MinimumExecutionTime time.Duration RandomExecutionTime time.Duration Errors []error + Responses []interface{} // Keep error after execution. The error will continue to be returned until // it is cleared. KeepError bool @@ -678,24 +679,27 @@ func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, e return result, nil } -func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { +func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) (interface{}, error) { s.mu.Lock() + defer s.mu.Unlock() + + // Check if the server is stopped if s.stopped { - s.mu.Unlock() - return gstatus.Error(codes.Unavailable, "server has been stopped") + return nil, gstatus.Error(codes.Unavailable, "server has been stopped") } + + // Send the request to the receivedRequests channel s.receivedRequests <- req - s.mu.Unlock() - s.ready() - s.mu.Lock() + + // Check for a simulated error if s.err != nil { err := s.err s.err = nil - s.mu.Unlock() - return err + return nil, err } + + // Check for a simulated execution time executionTime, ok := s.executionTimes[method] - s.mu.Unlock() if ok { var randTime int64 if executionTime.RandomExecutionTime > 0 { @@ -703,22 +707,29 @@ func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{ } totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) <-time.After(totalExecutionTime) - s.mu.Lock() + + // Check for errors in the execution time if len(executionTime.Errors) > 0 { err := executionTime.Errors[0] if !executionTime.KeepError { executionTime.Errors = executionTime.Errors[1:] } - s.mu.Unlock() - return err + return nil, err + } + + // Check for responses in the execution time + if len(executionTime.Responses) > 0 { + response := executionTime.Responses[0] + executionTime.Responses = executionTime.Responses[1:] + return response, nil } - s.mu.Unlock() } - return nil + + return nil, nil } func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { - if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { + if _, err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { return nil, err } if req.Database == "" { @@ -750,7 +761,7 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C } func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { - if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { + if _, err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { return nil, err } if req.Database == "" { @@ -792,7 +803,7 @@ func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spann } func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { - if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { + if _, err := s.simulateExecutionTime(MethodGetSession, req); err != nil { return nil, err } if req.Name == "" { @@ -833,7 +844,7 @@ func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.Li } func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { - if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { + if _, err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { return nil, err } if req.Name == "" { @@ -850,7 +861,7 @@ func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.D } func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { - if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil { + if _, err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil { return nil, err } if req.Sql == "SELECT 1" { @@ -905,7 +916,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec } func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { - if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { + if _, err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { return err } return s.executeStreamingSQL(req, stream) @@ -984,7 +995,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques } func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { - if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil { + if _, err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil { return nil, err } if req.Session == "" { @@ -1047,7 +1058,7 @@ func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadReques } func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { - if err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil { + if _, err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil { return err } sqlReq := &spannerpb.ExecuteSqlRequest{ @@ -1066,7 +1077,7 @@ func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream sp } func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { - if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { + if _, err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { return nil, err } if req.Session == "" { @@ -1085,7 +1096,8 @@ func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerp } func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) { - if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil { + mockResponse, err := s.simulateExecutionTime(MethodCommitTransaction, req) + if err != nil { return nil, err } if req.Session == "" { @@ -1107,8 +1119,11 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe } else { return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") } - s.removeTransaction(tx) - resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()} + resp, ok := mockResponse.(*spannerpb.CommitResponse) + if !ok { + resp = &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()} + s.removeTransaction(tx) + } if req.ReturnCommitStats { resp.CommitStats = &spannerpb.CommitResponse_CommitStats{ MutationCount: int64(1), @@ -1142,7 +1157,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba } func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { - if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil { + if _, err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil { return nil, err } s.mu.Lock() @@ -1214,7 +1229,7 @@ func DecodeResumeToken(t []byte) (uint64, error) { } func (s *inMemSpannerServer) BatchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error { - if err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil { + if _, err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil { return err } return s.batchWrite(req, stream) diff --git a/spanner/transaction.go b/spanner/transaction.go index be6ee4af7228..d9b93187328d 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -1703,17 +1703,31 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions if options.MaxCommitDelay != nil { maxCommitDelay = durationpb.New(*(options.MaxCommitDelay)) } - res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{ - Session: sid, - Transaction: &sppb.CommitRequest_TransactionId{ - TransactionId: t.tx, - }, - PrecommitToken: precommitToken, - RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag), - Mutations: mutationProtos, - ReturnCommitStats: options.ReturnCommitStats, - MaxCommitDelay: maxCommitDelay, - }, gax.WithGRPCOptions(grpc.Header(&md))) + performCommit := func(token *sppb.MultiplexedSessionPrecommitToken, includeMutations bool) (*sppb.CommitResponse, error) { + req := &sppb.CommitRequest{ + Session: sid, + Transaction: &sppb.CommitRequest_TransactionId{ + TransactionId: t.tx, + }, + PrecommitToken: token, + RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag), + ReturnCommitStats: options.ReturnCommitStats, + MaxCommitDelay: maxCommitDelay, + } + if includeMutations { + req.Mutations = mutationProtos + } + return client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md))) + } + // Initial commit attempt with mutations + res, err := performCommit(precommitToken, true) + if err != nil { + return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true)) + } + // Retry if MultiplexedSessionRetry is present, without mutations + if res.GetMultiplexedSessionRetry() != nil { + res, err = performCommit(res.GetPrecommitToken(), false) + } if getGFELatencyMetricsFlag() && md != nil && t.ct != nil { if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "commit"); err != nil { trace.TracePrintf(ctx, nil, "Error in recording GFE Latency. Try disabling and rerunning. Error: %v", err) @@ -1722,8 +1736,8 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions if metricErr := recordGFELatencyMetricsOT(ctx, md, "commit", t.otConfig); metricErr != nil { trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr) } - if e != nil { - return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(e, true)) + if err != nil { + return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true)) } if tstamp := res.GetCommitTimestamp(); tstamp != nil { resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 5617951a9696..9a2eb104f110 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -504,6 +504,101 @@ func TestReadWriteTransaction_PrecommitToken(t *testing.T) { } } +func TestCommitWithMultiplexedSessionRetry(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + DisableNativeMetrics: true, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + enableMultiplexSession: true, + enableMultiplexedSessionForRW: true, + }, + }) + defer teardown() + + // newCommitResponseWithPrecommitToken creates a simulated response with a PrecommitToken + newCommitResponseWithPrecommitToken := func() *sppb.CommitResponse { + precommitToken := &sppb.MultiplexedSessionPrecommitToken{ + PrecommitToken: []byte("commit-retry-precommit-token"), + } + + // Create a CommitResponse with the PrecommitToken + return &sppb.CommitResponse{ + MultiplexedSessionRetry: &sppb.CommitResponse_PrecommitToken{PrecommitToken: precommitToken}, + } + } + + // Simulate a commit response with a MultiplexedSessionRetry + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ + Responses: []interface{}{newCommitResponseWithPrecommitToken()}, + }) + + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + ms := []*Mutation{ + Insert("t_foo", []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}), + Update("t_foo", []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}), + } + if err := tx.BufferWrite(ms); err != nil { + return err + } + + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + + if _, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatalf("Commit failed: %v", err) + } + + // Verify that the commit was retried + requests := drainRequestsFromServer(server.TestSpanner) + commitCount := 0 + for _, req := range requests { + if commitReq, ok := req.(*sppb.CommitRequest); ok { + if !strings.Contains(commitReq.GetSession(), "multiplexed") { + t.Errorf("Expected session to be multiplexed") + } + commitCount++ + if commitCount == 1 { + // Validate that the first commit had mutations set + if len(commitReq.Mutations) == 0 { + t.Fatalf("Expected first commit to have mutations set") + } + if commitReq.PrecommitToken == nil || !strings.Contains(string(commitReq.PrecommitToken.PrecommitToken), "ResultSetPrecommitToken") { + t.Fatalf("Expected first commit to have precommit token 'ResultSetPrecommitToken', got: %v", commitReq.PrecommitToken) + } + } else if commitCount == 2 { + // Validate that the second commit attempt had mutations un-set + if len(commitReq.Mutations) != 0 { + t.Fatalf("Expected second commit to have no mutations set") + } + // Validate that the second commit had the precommit token set + if commitReq.PrecommitToken == nil || string(commitReq.PrecommitToken.PrecommitToken) != "commit-retry-precommit-token" { + t.Fatalf("Expected second commit to have precommit token 'commit-retry-precommit-token', got: %v", commitReq.PrecommitToken) + } + } + } + } + if commitCount != 2 { + t.Fatalf("Expected 2 commit attempts, got %d", commitCount) + } +} + func TestMutationOnlyCaseAborted(t *testing.T) { t.Parallel() ctx := context.Background()