diff --git a/adapter.go b/adapter.go index d101e7a..0cea971 100755 --- a/adapter.go +++ b/adapter.go @@ -22,7 +22,6 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/model" @@ -84,7 +83,7 @@ type Adapter struct { db *gorm.DB isFiltered bool transactionMu *sync.Mutex - muInitialize atomic.Bool + muInitialize sync.Once } // finalizer is the destructor for Adapter. @@ -200,8 +199,9 @@ func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (* } a := &Adapter{ - tablePrefix: prefix, - tableName: tableName, + tablePrefix: prefix, + tableName: tableName, + transactionMu: &sync.Mutex{}, } a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context}) @@ -259,9 +259,10 @@ func NewFilteredAdapter(driverName string, dataSourceName string, params ...inte // Casbin will not automatically call LoadPolicy() for a filtered adapter. func NewFilteredAdapterByDB(db *gorm.DB, prefix string, tableName string) (*Adapter, error) { adapter := &Adapter{ - tablePrefix: prefix, - tableName: tableName, - isFiltered: true, + tablePrefix: prefix, + tableName: tableName, + isFiltered: true, + transactionMu: &sync.Mutex{}, } adapter.db = db.Scopes(adapter.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context}) @@ -692,12 +693,11 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) error, opts ...*sql.TxOptions) error { // ensure the transactionMu is initialized if a.transactionMu == nil { - for a.muInitialize.CompareAndSwap(false, true) { + a.muInitialize.Do(func() { if a.transactionMu == nil { a.transactionMu = &sync.Mutex{} } - a.muInitialize.Store(false) - } + }) } // lock the transactionMu to ensure the transaction is thread-safe a.transactionMu.Lock() @@ -706,7 +706,7 @@ func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) erro oriAdapter := a.db // reload policy from database to sync with the transaction defer func() { - e.SetAdapter(&Adapter{db: oriAdapter}) + e.SetAdapter(&Adapter{db: oriAdapter, transactionMu: a.transactionMu}) err = e.LoadPolicy() if err != nil { panic(err) @@ -714,7 +714,7 @@ func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) erro }() copyDB := *a.db tx := copyDB.Begin(opts...) - b := &Adapter{db: tx} + b := &Adapter{db: tx, transactionMu: a.transactionMu} // copy enforcer to set the new adapter with transaction tx copyEnforcer := e copyEnforcer.SetAdapter(b)