Skip to content

Commit

Permalink
chore(spanner): support mutation only operation for read-write mux (#…
Browse files Browse the repository at this point in the history
…11342)

* chore(spanner): support mutation only operation for read-write mux

* incorporate changes
  • Loading branch information
rahul2393 authored Jan 2, 2025
1 parent e41a153 commit 7f81daf
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 89 deletions.
60 changes: 29 additions & 31 deletions spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,19 @@ func parseDatabaseName(db string) (project, instance, database string, err error
// Client is a client for reading and writing data to a Cloud Spanner database.
// A client is safe to use concurrently, except for its Close method.
type Client struct {
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
enableMultiplexedSessionForRW bool
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
sc *sessionClient
idleSessions *sessionPool
logger *log.Logger
qo QueryOptions
ro ReadOptions
ao []ApplyOption
txo TransactionOptions
bwo BatchWriteOptions
ct *commonTags
disableRouteToLeader bool
dro *sppb.DirectedReadOptions
otConfig *openTelemetryConfig
metricsTracerFactory *builtinMetricsTracerFactory
}

// DatabaseName returns the full name of a database, e.g.,
Expand Down Expand Up @@ -548,20 +547,19 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
}

c = &Client{
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
enableMultiplexedSessionForRW: config.enableMultiplexedSessionForRW,
sc: sc,
idleSessions: sp,
logger: config.Logger,
qo: getQueryOptions(config.QueryOptions),
ro: config.ReadOptions,
ao: config.ApplyOptions,
txo: config.TransactionOptions,
bwo: config.BatchWriteOptions,
ct: getCommonTags(sc),
disableRouteToLeader: config.DisableRouteToLeader,
dro: config.DirectedReadOptions,
otConfig: otConfig,
metricsTracerFactory: metricsTracerFactory,
}
return c, nil
}
Expand Down Expand Up @@ -1025,7 +1023,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
err error
)
if sh == nil || sh.getID() == "" || sh.getClient() == nil {
if c.enableMultiplexedSessionForRW {
if c.idleSessions.isMultiplexedSessionForRWEnabled() {
sh, err = c.idleSessions.takeMultiplexed(ctx)
} else {
// Session handle hasn't been allocated or has been destroyed.
Expand All @@ -1044,7 +1042,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
// Note that the t.begin(ctx) call could change the session that is being used by the transaction, as the
// BeginTransaction RPC invocation will be retried on a new session if it returns SessionNotFound.
t.txReadOnly.sh = sh
if err = t.begin(ctx); err != nil {
if err = t.begin(ctx, nil); err != nil {
trace.TracePrintf(ctx, nil, "Error while BeginTransaction during retrying a ReadWrite transaction: %v", ToSpannerError(err))
return ToSpannerError(err)
}
Expand Down Expand Up @@ -1072,7 +1070,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
return err
})
if isUnimplementedErrorForMultiplexedRW(err) {
c.enableMultiplexedSessionForRW = false
c.idleSessions.disableMultiplexedSessionForRW()
}
return resp, err
}
Expand Down
3 changes: 3 additions & 0 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,9 @@ func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerp
}
s.updateSessionLastUseTime(session.Name)
tx := s.beginTransaction(session, req.Options)
if session.Multiplexed && req.MutationKey != nil {
tx.PrecommitToken = s.getPreCommitToken(string(tx.Id), "TransactionPrecommitToken")
}
return tx, nil
}

Expand Down
36 changes: 32 additions & 4 deletions spanner/mutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ limitations under the License.
package spanner

import (
"math/rand"
"reflect"
"time"

sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -427,24 +429,50 @@ func (m Mutation) proto() (*sppb.Mutation, error) {

// mutationsProto turns a spanner.Mutation array into a sppb.Mutation array,
// it is convenient for sending batch mutations to Cloud Spanner.
func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, error) {
func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, *sppb.Mutation, error) {
var selectedMutation *Mutation
var nonInsertMutations []*Mutation

l := make([]*sppb.Mutation, 0, len(ms))
for _, m := range ms {
if m.op != opInsert {
nonInsertMutations = append(nonInsertMutations, m)
}
if selectedMutation == nil {
selectedMutation = m
}
// Track the INSERT mutation with the highest number of values if only INSERT mutation were found
if selectedMutation.op == opInsert && m.op == opInsert && len(m.values) > len(selectedMutation.values) {
selectedMutation = m
}

// Convert the mutation to sppb.Mutation and add to the list
pb, err := m.proto()
if err != nil {
return nil, err
return nil, nil, err
}
l = append(l, pb)
}
return l, nil
if len(nonInsertMutations) > 0 {
selectedMutation = nonInsertMutations[rand.New(rand.NewSource(time.Now().UnixNano())).Intn(len(nonInsertMutations))]
}
if selectedMutation != nil {
m, err := selectedMutation.proto()
if err != nil {
return nil, nil, err
}
return l, m, nil
}

return l, nil, nil
}

// mutationGroupsProto turns a spanner.MutationGroup array into a
// sppb.BatchWriteRequest_MutationGroup array, in preparation to send RPCs.
func mutationGroupsProto(mgs []*MutationGroup) ([]*sppb.BatchWriteRequest_MutationGroup, error) {
gs := make([]*sppb.BatchWriteRequest_MutationGroup, 0, len(mgs))
for _, mg := range mgs {
ms, err := mutationsProto(mg.Mutations)
ms, _, err := mutationsProto(mg.Mutations)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 7f81daf

Please sign in to comment.