Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(spanner): handle commit retry protocol extension for mux rw #11472

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 43 additions & 28 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ type SimulatedExecutionTime struct {
MinimumExecutionTime time.Duration
RandomExecutionTime time.Duration
Errors []error
Responses []interface{}
// Keep error after execution. The error will continue to be returned until
// it is cleared.
KeepError bool
Expand Down Expand Up @@ -678,47 +679,57 @@ func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, e
return result, nil
}

func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error {
func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) (interface{}, error) {
s.mu.Lock()
defer s.mu.Unlock()

// Check if the server is stopped
if s.stopped {
s.mu.Unlock()
return gstatus.Error(codes.Unavailable, "server has been stopped")
return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
}

// Send the request to the receivedRequests channel
s.receivedRequests <- req
s.mu.Unlock()
s.ready()
s.mu.Lock()

// Check for a simulated error
if s.err != nil {
err := s.err
s.err = nil
s.mu.Unlock()
return err
return nil, err
}

// Check for a simulated execution time
executionTime, ok := s.executionTimes[method]
s.mu.Unlock()
if ok {
var randTime int64
if executionTime.RandomExecutionTime > 0 {
randTime = rand.Int63n(int64(executionTime.RandomExecutionTime))
}
totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
<-time.After(totalExecutionTime)
s.mu.Lock()

// Check for errors in the execution time
if len(executionTime.Errors) > 0 {
err := executionTime.Errors[0]
if !executionTime.KeepError {
executionTime.Errors = executionTime.Errors[1:]
}
s.mu.Unlock()
return err
return nil, err
}

// Check for responses in the execution time
if len(executionTime.Responses) > 0 {
response := executionTime.Responses[0]
executionTime.Responses = executionTime.Responses[1:]
return response, nil
}
s.mu.Unlock()
}
return nil

return nil, nil
}

func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
return nil, err
}
if req.Database == "" {
Expand Down Expand Up @@ -750,7 +761,7 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C
}

func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
return nil, err
}
if req.Database == "" {
Expand Down Expand Up @@ -792,7 +803,7 @@ func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spann
}

func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
return nil, err
}
if req.Name == "" {
Expand Down Expand Up @@ -833,7 +844,7 @@ func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.Li
}

func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
return nil, err
}
if req.Name == "" {
Expand All @@ -850,7 +861,7 @@ func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.D
}

func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
if _, err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
return nil, err
}
if req.Sql == "SELECT 1" {
Expand Down Expand Up @@ -905,7 +916,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
}

func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
if _, err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
return err
}
return s.executeStreamingSQL(req, stream)
Expand Down Expand Up @@ -984,7 +995,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
}

func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
if _, err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
return nil, err
}
if req.Session == "" {
Expand Down Expand Up @@ -1047,7 +1058,7 @@ func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadReques
}

func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
if err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
if _, err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
return err
}
sqlReq := &spannerpb.ExecuteSqlRequest{
Expand All @@ -1066,7 +1077,7 @@ func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream sp
}

func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
if _, err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
return nil, err
}
if req.Session == "" {
Expand All @@ -1085,7 +1096,8 @@ func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerp
}

func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
mockResponse, err := s.simulateExecutionTime(MethodCommitTransaction, req)
if err != nil {
return nil, err
}
if req.Session == "" {
Expand All @@ -1107,8 +1119,11 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
} else {
return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
}
s.removeTransaction(tx)
resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
resp, ok := mockResponse.(*spannerpb.CommitResponse)
if !ok {
resp = &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
s.removeTransaction(tx)
}
if req.ReturnCommitStats {
resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
MutationCount: int64(1),
Expand Down Expand Up @@ -1142,7 +1157,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
}

func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
if _, err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
return nil, err
}
s.mu.Lock()
Expand Down Expand Up @@ -1214,7 +1229,7 @@ func DecodeResumeToken(t []byte) (uint64, error) {
}

func (s *inMemSpannerServer) BatchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error {
if err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil {
if _, err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil {
return err
}
return s.batchWrite(req, stream)
Expand Down
40 changes: 27 additions & 13 deletions spanner/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -1703,17 +1703,31 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
if options.MaxCommitDelay != nil {
maxCommitDelay = durationpb.New(*(options.MaxCommitDelay))
}
res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{
Session: sid,
Transaction: &sppb.CommitRequest_TransactionId{
TransactionId: t.tx,
},
PrecommitToken: precommitToken,
RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag),
Mutations: mutationProtos,
ReturnCommitStats: options.ReturnCommitStats,
MaxCommitDelay: maxCommitDelay,
}, gax.WithGRPCOptions(grpc.Header(&md)))
performCommit := func(token *sppb.MultiplexedSessionPrecommitToken, includeMutations bool) (*sppb.CommitResponse, error) {
req := &sppb.CommitRequest{
Session: sid,
Transaction: &sppb.CommitRequest_TransactionId{
TransactionId: t.tx,
},
PrecommitToken: token,
RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag),
ReturnCommitStats: options.ReturnCommitStats,
MaxCommitDelay: maxCommitDelay,
}
if includeMutations {
req.Mutations = mutationProtos
}
return client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))
}
// Initial commit attempt with mutations
res, err := performCommit(precommitToken, true)
if err != nil {
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true))
}
// Retry if MultiplexedSessionRetry is present, without mutations
if res.GetMultiplexedSessionRetry() != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we take the precommit token from the response here, pass it to the setPrecommitToken method, and then use t.precommittoken below?
I understand that the precommit token received in res should be the latest, but we should rely on the sequence number to determine that it is indeed the latest. This will protect us in the future in cases where the precommit token received in res could be outdated.

res, err = performCommit(res.GetPrecommitToken(), false)
}
if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "commit"); err != nil {
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency. Try disabling and rerunning. Error: %v", err)
Expand All @@ -1722,8 +1736,8 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
if metricErr := recordGFELatencyMetricsOT(ctx, md, "commit", t.otConfig); metricErr != nil {
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr)
}
if e != nil {
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(e, true))
if err != nil {
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true))
}
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
Expand Down
95 changes: 95 additions & 0 deletions spanner/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,101 @@ func TestReadWriteTransaction_PrecommitToken(t *testing.T) {
}
}

func TestCommitWithMultiplexedSessionRetry(t *testing.T) {
ctx := context.Background()
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
DisableNativeMetrics: true,
SessionPoolConfig: SessionPoolConfig{
MinOpened: 1,
MaxOpened: 1,
enableMultiplexSession: true,
enableMultiplexedSessionForRW: true,
},
})
defer teardown()

// newCommitResponseWithPrecommitToken creates a simulated response with a PrecommitToken
newCommitResponseWithPrecommitToken := func() *sppb.CommitResponse {
precommitToken := &sppb.MultiplexedSessionPrecommitToken{
PrecommitToken: []byte("commit-retry-precommit-token"),
}

// Create a CommitResponse with the PrecommitToken
return &sppb.CommitResponse{
MultiplexedSessionRetry: &sppb.CommitResponse_PrecommitToken{PrecommitToken: precommitToken},
}
}

// Simulate a commit response with a MultiplexedSessionRetry
server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
SimulatedExecutionTime{
Responses: []interface{}{newCommitResponseWithPrecommitToken()},
})

_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
ms := []*Mutation{
Insert("t_foo", []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}),
Update("t_foo", []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}),
}
if err := tx.BufferWrite(ms); err != nil {
return err
}

iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
}

if _, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil {
return err
Comment on lines +559 to +560
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commit retry only happens in the DQL + mutation case. Let's remove the DML statement to avoid confusion.

}
return nil
})
if err != nil {
t.Fatalf("Commit failed: %v", err)
}

// Verify that the commit was retried
requests := drainRequestsFromServer(server.TestSpanner)
commitCount := 0
for _, req := range requests {
if commitReq, ok := req.(*sppb.CommitRequest); ok {
if !strings.Contains(commitReq.GetSession(), "multiplexed") {
t.Errorf("Expected session to be multiplexed")
}
commitCount++
if commitCount == 1 {
// Validate that the first commit had mutations set
if len(commitReq.Mutations) == 0 {
t.Fatalf("Expected first commit to have mutations set")
}
if commitReq.PrecommitToken == nil || !strings.Contains(string(commitReq.PrecommitToken.PrecommitToken), "ResultSetPrecommitToken") {
t.Fatalf("Expected first commit to have precommit token 'ResultSetPrecommitToken', got: %v", commitReq.PrecommitToken)
}
} else if commitCount == 2 {
// Validate that the second commit attempt had mutations un-set
if len(commitReq.Mutations) != 0 {
t.Fatalf("Expected second commit to have no mutations set")
}
// Validate that the second commit had the precommit token set
if commitReq.PrecommitToken == nil || string(commitReq.PrecommitToken.PrecommitToken) != "commit-retry-precommit-token" {
t.Fatalf("Expected second commit to have precommit token 'commit-retry-precommit-token', got: %v", commitReq.PrecommitToken)
}
}
}
}
if commitCount != 2 {
t.Fatalf("Expected 2 commit attempts, got %d", commitCount)
}
}

func TestMutationOnlyCaseAborted(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down
Loading