Skip to content

Commit

Permalink
chore(spanner): handle commit retry protocol extension for mux rw
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul2393 committed Jan 20, 2025
1 parent aa54375 commit 3dba7b6
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 41 deletions.
71 changes: 43 additions & 28 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ type SimulatedExecutionTime struct {
MinimumExecutionTime time.Duration
RandomExecutionTime time.Duration
Errors []error
Responses []interface{}
// Keep error after execution. The error will continue to be returned until
// it is cleared.
KeepError bool
Expand Down Expand Up @@ -678,47 +679,57 @@ func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, e
return result, nil
}

func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error {
func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) (interface{}, error) {
s.mu.Lock()
defer s.mu.Unlock()

// Check if the server is stopped
if s.stopped {
s.mu.Unlock()
return gstatus.Error(codes.Unavailable, "server has been stopped")
return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
}

// Send the request to the receivedRequests channel
s.receivedRequests <- req
s.mu.Unlock()
s.ready()
s.mu.Lock()

// Check for a simulated error
if s.err != nil {
err := s.err
s.err = nil
s.mu.Unlock()
return err
return nil, err
}

// Check for a simulated execution time
executionTime, ok := s.executionTimes[method]
s.mu.Unlock()
if ok {
var randTime int64
if executionTime.RandomExecutionTime > 0 {
randTime = rand.Int63n(int64(executionTime.RandomExecutionTime))
}
totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
<-time.After(totalExecutionTime)
s.mu.Lock()

// Check for errors in the execution time
if len(executionTime.Errors) > 0 {
err := executionTime.Errors[0]
if !executionTime.KeepError {
executionTime.Errors = executionTime.Errors[1:]
}
s.mu.Unlock()
return err
return nil, err
}

// Check for responses in the execution time
if len(executionTime.Responses) > 0 {
response := executionTime.Responses[0]
executionTime.Responses = executionTime.Responses[1:]
return response, nil
}
s.mu.Unlock()
}
return nil

return nil, nil
}

func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
return nil, err
}
if req.Database == "" {
Expand Down Expand Up @@ -750,7 +761,7 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C
}

func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
return nil, err
}
if req.Database == "" {
Expand Down Expand Up @@ -792,7 +803,7 @@ func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spann
}

func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
return nil, err
}
if req.Name == "" {
Expand Down Expand Up @@ -833,7 +844,7 @@ func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.Li
}

func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
if _, err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
return nil, err
}
if req.Name == "" {
Expand All @@ -850,7 +861,7 @@ func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.D
}

func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
if _, err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
return nil, err
}
if req.Sql == "SELECT 1" {
Expand Down Expand Up @@ -905,7 +916,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
}

func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
if _, err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
return err
}
return s.executeStreamingSQL(req, stream)
Expand Down Expand Up @@ -984,7 +995,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
}

func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
if _, err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
return nil, err
}
if req.Session == "" {
Expand Down Expand Up @@ -1047,7 +1058,7 @@ func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadReques
}

func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
if err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
if _, err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
return err
}
sqlReq := &spannerpb.ExecuteSqlRequest{
Expand All @@ -1066,7 +1077,7 @@ func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream sp
}

func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
if _, err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
return nil, err
}
if req.Session == "" {
Expand All @@ -1085,7 +1096,8 @@ func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerp
}

func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
mockResponse, err := s.simulateExecutionTime(MethodCommitTransaction, req)
if err != nil {
return nil, err
}
if req.Session == "" {
Expand All @@ -1107,8 +1119,11 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
} else {
return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
}
s.removeTransaction(tx)
resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
resp, ok := mockResponse.(*spannerpb.CommitResponse)
if !ok {
resp = &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
s.removeTransaction(tx)
}
if req.ReturnCommitStats {
resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
MutationCount: int64(1),
Expand Down Expand Up @@ -1142,7 +1157,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
}

func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
if _, err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
return nil, err
}
s.mu.Lock()
Expand Down Expand Up @@ -1214,7 +1229,7 @@ func DecodeResumeToken(t []byte) (uint64, error) {
}

func (s *inMemSpannerServer) BatchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error {
if err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil {
if _, err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil {
return err
}
return s.batchWrite(req, stream)
Expand Down
40 changes: 27 additions & 13 deletions spanner/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -1703,17 +1703,31 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
if options.MaxCommitDelay != nil {
maxCommitDelay = durationpb.New(*(options.MaxCommitDelay))
}
res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{
Session: sid,
Transaction: &sppb.CommitRequest_TransactionId{
TransactionId: t.tx,
},
PrecommitToken: precommitToken,
RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag),
Mutations: mutationProtos,
ReturnCommitStats: options.ReturnCommitStats,
MaxCommitDelay: maxCommitDelay,
}, gax.WithGRPCOptions(grpc.Header(&md)))
performCommit := func(token *sppb.MultiplexedSessionPrecommitToken, includeMutations bool) (*sppb.CommitResponse, error) {
req := &sppb.CommitRequest{
Session: sid,
Transaction: &sppb.CommitRequest_TransactionId{
TransactionId: t.tx,
},
PrecommitToken: token,
RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag),
ReturnCommitStats: options.ReturnCommitStats,
MaxCommitDelay: maxCommitDelay,
}
if includeMutations {
req.Mutations = mutationProtos
}
return client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))
}
// Initial commit attempt with mutations
res, err := performCommit(precommitToken, true)
if err != nil {
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true))
}
// Retry if MultiplexedSessionRetry is present, without mutations
if res.GetMultiplexedSessionRetry() != nil {
res, err = performCommit(res.GetPrecommitToken(), false)
}
if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "commit"); err != nil {
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency. Try disabling and rerunning. Error: %v", err)
Expand All @@ -1722,8 +1736,8 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
if metricErr := recordGFELatencyMetricsOT(ctx, md, "commit", t.otConfig); metricErr != nil {
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr)
}
if e != nil {
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(e, true))
if err != nil {
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true))
}
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
Expand Down
95 changes: 95 additions & 0 deletions spanner/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,101 @@ func TestReadWriteTransaction_PrecommitToken(t *testing.T) {
}
}

func TestCommitWithMultiplexedSessionRetry(t *testing.T) {
ctx := context.Background()
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
DisableNativeMetrics: true,
SessionPoolConfig: SessionPoolConfig{
MinOpened: 1,
MaxOpened: 1,
enableMultiplexSession: true,
enableMultiplexedSessionForRW: true,
},
})
defer teardown()

// newCommitResponseWithPrecommitToken creates a simulated response with a PrecommitToken
newCommitResponseWithPrecommitToken := func() *sppb.CommitResponse {
precommitToken := &sppb.MultiplexedSessionPrecommitToken{
PrecommitToken: []byte("commit-retry-precommit-token"),
}

// Create a CommitResponse with the PrecommitToken
return &sppb.CommitResponse{
MultiplexedSessionRetry: &sppb.CommitResponse_PrecommitToken{PrecommitToken: precommitToken},
}
}

// Simulate a commit response with a MultiplexedSessionRetry
server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
SimulatedExecutionTime{
Responses: []interface{}{newCommitResponseWithPrecommitToken()},
})

_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
ms := []*Mutation{
Insert("t_foo", []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}),
Update("t_foo", []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}),
}
if err := tx.BufferWrite(ms); err != nil {
return err
}

iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
}

if _, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo}); err != nil {
return err
}
return nil
})
if err != nil {
t.Fatalf("Commit failed: %v", err)
}

// Verify that the commit was retried
requests := drainRequestsFromServer(server.TestSpanner)
commitCount := 0
for _, req := range requests {
if commitReq, ok := req.(*sppb.CommitRequest); ok {
if !strings.Contains(commitReq.GetSession(), "multiplexed") {
t.Errorf("Expected session to be multiplexed")
}
commitCount++
if commitCount == 1 {
// Validate that the first commit had mutations set
if len(commitReq.Mutations) == 0 {
t.Fatalf("Expected first commit to have mutations set")
}
if commitReq.PrecommitToken == nil || !strings.Contains(string(commitReq.PrecommitToken.PrecommitToken), "ResultSetPrecommitToken") {
t.Fatalf("Expected first commit to have precommit token 'ResultSetPrecommitToken', got: %v", commitReq.PrecommitToken)
}
} else if commitCount == 2 {
// Validate that the second commit attempt had mutations un-set
if len(commitReq.Mutations) != 0 {
t.Fatalf("Expected second commit to have no mutations set")
}
// Validate that the second commit had the precommit token set
if commitReq.PrecommitToken == nil || string(commitReq.PrecommitToken.PrecommitToken) != "commit-retry-precommit-token" {
t.Fatalf("Expected second commit to have precommit token 'commit-retry-precommit-token', got: %v", commitReq.PrecommitToken)
}
}
}
}
if commitCount != 2 {
t.Fatalf("Expected 2 commit attempts, got %d", commitCount)
}
}

func TestMutationOnlyCaseAborted(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down

0 comments on commit 3dba7b6

Please sign in to comment.