Skip to content

Commit

Permalink
return-txn-current-if-exists
Browse files Browse the repository at this point in the history
  • Loading branch information
brucecurcio committed Jan 14, 2025
1 parent d104685 commit ad36bbd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 161 deletions.
35 changes: 9 additions & 26 deletions gorm/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,27 +103,23 @@ func (t *Transaction) AddAfterCommitHook(hooks ...func(context.Context)) {
}

// getReadOnlyDBInstance returns the read only db txn if RO DB available otherwise it returns read/write db txn
// unless current txn already exists
func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) {
if txn.current != nil {
return txn.current, nil
}
var db *gorm.DB
switch {
case txn.parentRO == nil:
return getReadWriteDBTxn(ctx, opts, txn)
case opts.txOpts != nil && txn.currentOpts.txOpts != nil:
if *opts.txOpts != *txn.currentOpts.txOpts {
return nil, ErrCtxTxnOptMismatch
}
case opts.txOpts != nil:
// We should error in two cases 1. We should error if read-only DB requested with read-write txn
// 2. If no txn options provided in previous call but provided in subsequent call
if !opts.txOpts.ReadOnly || txn.currentOpts.database != dbNotSet {
if !opts.txOpts.ReadOnly {
return nil, ErrCtxTxnOptMismatch
}
txnOpts := *opts.txOpts
txn.currentOpts.txOpts = &txnOpts
}
if txn.current != nil {
return txn.current, nil
}
db = txn.beginReadOnlyWithContextAndOptions(ctx, txn.currentOpts.txOpts)
if db.Error != nil {
return nil, db.Error
Expand All @@ -135,26 +131,19 @@ func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transacti
}

// getReadWriteDBTxn returns the read/write db txn
// If current txn already exists, use it as is; opts are not applied.
func getReadWriteDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) {
if txn.current != nil {
return txn.current, nil
}
var db *gorm.DB
switch {
case txn.parent == nil:
return nil, ErrCtxTxnNoDB
case opts.txOpts != nil && txn.currentOpts.txOpts != nil:
if *opts.txOpts != *txn.currentOpts.txOpts {
return nil, ErrCtxTxnOptMismatch
}
case opts.txOpts != nil:
// We should return error If no txn options provided in previous call but provided in subsequent call
if txn.currentOpts.database != dbNotSet {
return nil, ErrCtxTxnOptMismatch
}
txnOpts := *opts.txOpts
txn.currentOpts.txOpts = &txnOpts
}
if txn.current != nil {
return txn.current, nil
}
db = txn.beginWithContextAndOptions(ctx, txn.currentOpts.txOpts)
if db.Error != nil {
return nil, db.Error
Expand All @@ -178,14 +167,8 @@ func BeginFromContext(ctx context.Context, options ...DatabaseOption) (*gorm.DB,
opts := toDatabaseOptions(options...)
switch opts.database {
case dbReadOnly:
if txn.currentOpts.database == dbReadWrite && txn.parentRO != nil {
return nil, ErrCtxDBOptMismatch
}
return getReadOnlyDBTxn(ctx, opts, txn)
case dbReadWrite:
if txn.currentOpts.database == dbReadOnly {
return nil, ErrCtxDBOptMismatch
}
return getReadWriteDBTxn(ctx, opts, txn)
default:
// This is the case to handle when no database options provided
Expand Down
160 changes: 25 additions & 135 deletions gorm/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,27 +499,17 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) {
t.Errorf("failed to begin transaction for read-write db - %s", err)
}
test.withOpts = readOnly
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxDBOptMismatch {
t.Error("begin transaction should fail with an error DBOptionsMismatch")
}
test.withOpts = readWrite
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
t.Error("Received an error beginning transaction")
}
if txn3 == nil {
if txn2 == nil {
t.Error("Did not receive a transaction from context")
}
// Case: Transaction begin is idempotent
if txn1 != txn3 {
// Case: withOpts are ignored if txn is already open, current txn is returned
if txn1 != txn2 {
t.Error("Got a different txn than was opened before")
}
test.txOpts = &sql.TxOptions{}
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxnOptionsMismatch")
}
} else {
txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
Expand All @@ -532,54 +522,16 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) {
t.Errorf("failed to begin transaction for read-write db - %s", err)
}
test.txOpts.ReadOnly = true
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.ReadOnly = false
test.txOpts.Isolation = sql.LevelSerializable
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.Isolation = sql.LevelDefault
test.withOpts = readOnly
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxDBOptMismatch {
t.Error("begin transaction should fail with an error DBOptionsMismatch")
}
test.withOpts = readWrite
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
t.Error("Received an error beginning transaction")
}
if txn3 == nil {
t.Error("Did not receive a transaction from context")
}
// Case: Transaction begin is idempotent
if txn1 != txn3 {
t.Error("Got a different txn than was opened before")
}
test.txOpts.ReadOnly = true
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.ReadOnly = false
test.txOpts.Isolation = sql.LevelSerializable
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
txn4, err := beginFromContextWithOptions(ctx, test.withOpts, nil)
test.txOpts.Isolation = sql.LevelRepeatableRead
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
t.Error("Received an error beginning transaction")
}
if txn4 == nil {
if txn2 == nil {
t.Error("Did not receive a transaction from context")
}
// Case: Transaction begin is idempotent
if txn1 != txn4 {
// Case: txOpts are ignored if txn is already open, current txn is returned
if txn1 != txn2 {
t.Error("Got a different txn than was opened before")
}
}
Expand Down Expand Up @@ -647,21 +599,10 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
if txn2 == nil {
t.Error("Did not receive a transaction from context")
}
// Case: Transaction begin is idempotent
// Case: withOpts are ignored if txn is already open, current txn is returned
if txn1 != txn2 {
t.Error("Got a different txn than was opened before")
}
test.withOpts = readWrite
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxDBOptMismatch {
t.Error("begin transaction should fail with an error DBOptionsMismatch")
}
test.withOpts = noOptions
test.txOpts = &sql.TxOptions{}
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxnOptionsMismatch")
}
} else {
_, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
Expand All @@ -678,19 +619,9 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
if err := dbROMock.ExpectationsWereMet(); err != nil {
t.Errorf("failed to begin transaction for read-only db - %s", err)
}
// Case: txOpts are ignored if txn is already open, current txn is returned
test.txOpts.ReadOnly = false
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.ReadOnly = true
test.txOpts.Isolation = sql.LevelSerializable
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.Isolation = sql.LevelDefault
test.withOpts = noOptions
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
t.Error("Received an error beginning transaction")
Expand All @@ -701,7 +632,13 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
if txn1 != txn2 {
t.Error("Got a different txn than was opened before")
}
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil)

test.txOpts.Isolation = sql.LevelDefault
test.withOpts = noOptions

// Case: withOpts are ignored if txn is already open, current txn is returned
test.withOpts = readWrite
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
t.Error("Received an error beginning transaction")
}
Expand All @@ -711,11 +648,6 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) {
if txn1 != txn3 {
t.Error("Got a different txn than was opened before")
}
test.withOpts = readWrite
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxDBOptMismatch {
t.Error("begin transaction should fail with an error DBOptionsMismatch")
}
}
})
}
Expand Down Expand Up @@ -773,26 +705,21 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) {
t.Errorf("failed to begin transaction for read-write db - %s", err)
}
test.withOpts = readOnly
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxDBOptMismatch {
t.Error("begin transaction should fail with an error DBOptionsMismatch")
}
test.withOpts = noOptions
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
t.Error("Received an error beginning transaction")
}
if txn3 == nil {
if txn1 == nil {
t.Error("Did not receive a transaction from context")
}
// Case: Transaction begin is idempotent
if txn1 != txn3 {
// Case: withOpts are ignored if txn is already open, current txn is returned
if txn1 != txn2 {
t.Error("Got a different txn than was opened before")
}
test.txOpts = &sql.TxOptions{}
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
if err != nil {
t.Error("Received an error beginning transaction")
}
} else {
txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
Expand All @@ -806,55 +733,18 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) {
t.Errorf("failed to begin transaction for read-write db - %s", err)
}
test.txOpts.ReadOnly = true
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.ReadOnly = false
test.txOpts.Isolation = sql.LevelSerializable
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.Isolation = sql.LevelDefault
test.withOpts = readOnly
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxDBOptMismatch {
t.Error("begin transaction should fail with an error DBOptionsMismatch")
}
test.withOpts = noOptions
txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != nil {
t.Error("Received an error beginning transaction")
}
if txn2 == nil {
t.Error("Did not receive a transaction from context")
}
// Case: Transaction begin is idempotent
// Case: txOpts are ignored if txn is already open, current txn is returned
if txn1 != txn2 {
t.Error("Got a different txn than was opened before")
}
test.txOpts.ReadOnly = true
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
test.txOpts.ReadOnly = false
test.txOpts.Isolation = sql.LevelSerializable
_, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts)
if err != ErrCtxTxnOptMismatch {
t.Error("begin transaction should fail with an error TxOptionsMismatch")
}
txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil)
if err != nil {
t.Error("Received an error beginning transaction")
}
if txn3 == nil {
t.Error("Did not receive a transaction from context")
}
if txn1 != txn3 {
t.Error("Got a different txn than was opened before")
}
}
})
}
Expand Down

0 comments on commit ad36bbd

Please sign in to comment.